mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-26 01:46:28 +03:00
Fix memory leak in beam parser
This commit is contained in:
parent
86ddf692a1
commit
2512ea9eeb
|
@ -9,36 +9,31 @@ from thinc.typedefs cimport hash_t, class_t
|
|||
from thinc.extra.search cimport MaxViolation
|
||||
|
||||
from .transition_system cimport TransitionSystem, Transition
|
||||
from .stateclass cimport StateClass
|
||||
from ..gold cimport GoldParse
|
||||
from .stateclass cimport StateC, StateClass
|
||||
|
||||
|
||||
# These are passed as callbacks to thinc.search.Beam
|
||||
cdef int _transition_state(void* _dest, void* _src, class_t clas, void* _moves) except -1:
|
||||
dest = <StateClass>_dest
|
||||
src = <StateClass>_src
|
||||
dest = <StateC*>_dest
|
||||
src = <StateC*>_src
|
||||
moves = <const Transition*>_moves
|
||||
dest.clone(src)
|
||||
moves[clas].do(dest.c, moves[clas].label)
|
||||
dest.c.push_hist(clas)
|
||||
moves[clas].do(dest, moves[clas].label)
|
||||
dest.push_hist(clas)
|
||||
|
||||
|
||||
cdef int _check_final_state(void* _state, void* extra_args) except -1:
|
||||
return (<StateClass>_state).is_final()
|
||||
|
||||
|
||||
def _cleanup(Beam beam):
|
||||
for i in range(beam.width):
|
||||
Py_XDECREF(<PyObject*>beam._states[i].content)
|
||||
Py_XDECREF(<PyObject*>beam._parents[i].content)
|
||||
state = <StateC*>_state
|
||||
return state.is_final()
|
||||
|
||||
|
||||
cdef hash_t _hash_state(void* _state, void* _) except 0:
|
||||
state = <StateClass>_state
|
||||
if state.c.is_final():
|
||||
state = <StateC*>_state
|
||||
if state.is_final():
|
||||
return 1
|
||||
else:
|
||||
return state.c.hash()
|
||||
return state.hash()
|
||||
|
||||
|
||||
cdef class ParserBeam(object):
|
||||
|
@ -55,14 +50,15 @@ cdef class ParserBeam(object):
|
|||
self.golds = golds
|
||||
self.beams = []
|
||||
cdef Beam beam
|
||||
cdef StateClass state, st
|
||||
cdef StateClass state
|
||||
cdef StateC* st
|
||||
for state in states:
|
||||
beam = Beam(self.moves.n_moves, width, density)
|
||||
beam.initialize(self.moves.init_beam_state, state.c.length,
|
||||
state.c._sent)
|
||||
for i in range(beam.width):
|
||||
st = <StateClass>beam.at(i)
|
||||
st.c.offset = state.c.offset
|
||||
st = <StateC*>beam.at(i)
|
||||
st.offset = state.c.offset
|
||||
self.beams.append(beam)
|
||||
self.dones = [False] * len(self.beams)
|
||||
|
||||
|
@ -86,13 +82,14 @@ cdef class ParserBeam(object):
|
|||
if self.golds is not None:
|
||||
self._set_costs(beam, self.golds[i], follow_gold=follow_gold)
|
||||
if follow_gold:
|
||||
beam.advance(_transition_state, NULL, <void*>self.moves.c)
|
||||
beam.advance(_transition_state, _hash_state, <void*>self.moves.c)
|
||||
else:
|
||||
beam.advance(_transition_state, _hash_state, <void*>self.moves.c)
|
||||
beam.check_done(_check_final_state, NULL)
|
||||
# This handles the non-monotonic stuff for the parser.
|
||||
if beam.is_done and self.golds is not None:
|
||||
for j in range(beam.size):
|
||||
state = <StateClass>beam.at(j)
|
||||
state = StateClass.borrow(<StateC*>beam.at(j))
|
||||
if state.is_final():
|
||||
try:
|
||||
if self.moves.is_gold_parse(state, self.golds[i]):
|
||||
|
@ -107,11 +104,11 @@ cdef class ParserBeam(object):
|
|||
cdef int nr_state = min(scores.shape[0], beam.size)
|
||||
cdef int nr_class = scores.shape[1]
|
||||
for i in range(nr_state):
|
||||
state = <StateClass>beam.at(i)
|
||||
state = <StateC*>beam.at(i)
|
||||
if not state.is_final():
|
||||
for j in range(nr_class):
|
||||
beam.scores[i][j] = c_scores[i * nr_class + j]
|
||||
self.moves.set_valid(beam.is_valid[i], state.c)
|
||||
self.moves.set_valid(beam.is_valid[i], state)
|
||||
else:
|
||||
for j in range(beam.nr_class):
|
||||
beam.scores[i][j] = 0
|
||||
|
@ -119,8 +116,8 @@ cdef class ParserBeam(object):
|
|||
|
||||
def _set_costs(self, Beam beam, GoldParse gold, int follow_gold=False):
|
||||
for i in range(beam.size):
|
||||
state = <StateClass>beam.at(i)
|
||||
if not state.c.is_final():
|
||||
state = StateClass.borrow(<StateC*>beam.at(i))
|
||||
if not state.is_final():
|
||||
self.moves.set_costs(beam.is_valid[i], beam.costs[i],
|
||||
state, gold)
|
||||
if follow_gold:
|
||||
|
@ -157,7 +154,7 @@ def update_beam(TransitionSystem moves, int nr_feature, int max_steps,
|
|||
pbeam = ParserBeam(moves, states, golds,
|
||||
width=width, density=density)
|
||||
gbeam = ParserBeam(moves, states, golds,
|
||||
width=width, density=0.0)
|
||||
width=width, density=density)
|
||||
cdef StateClass state
|
||||
beam_maps = []
|
||||
backprops = []
|
||||
|
@ -231,7 +228,7 @@ def get_states(pbeams, gbeams, beam_map, nr_update):
|
|||
p_indices.append([])
|
||||
g_indices.append([])
|
||||
for i in range(pbeam.size):
|
||||
state = <StateClass>pbeam.at(i)
|
||||
state = StateClass.borrow(<StateC*>pbeam.at(i))
|
||||
if not state.is_final():
|
||||
key = tuple([eg_id] + pbeam.histories[i])
|
||||
assert key not in seen, (key, seen)
|
||||
|
@ -240,7 +237,7 @@ def get_states(pbeams, gbeams, beam_map, nr_update):
|
|||
states.append(state)
|
||||
beam_map.update(seen)
|
||||
for i in range(gbeam.size):
|
||||
state = <StateClass>gbeam.at(i)
|
||||
state = StateClass.borrow(<StateC*>gbeam.at(i))
|
||||
if not state.is_final():
|
||||
key = tuple([eg_id] + gbeam.histories[i])
|
||||
if key in seen:
|
||||
|
|
|
@ -292,12 +292,16 @@ cdef int _get_root(int word, const GoldParseC* gold) nogil:
|
|||
|
||||
|
||||
cdef void* _init_state(Pool mem, int length, void* tokens) except NULL:
|
||||
cdef StateClass st = StateClass.init(<const TokenC*>tokens, length)
|
||||
for i in range(st.c.length):
|
||||
st.c._sent[i].l_edge = i
|
||||
st.c._sent[i].r_edge = i
|
||||
st = new StateC(<const TokenC*>tokens, length)
|
||||
for i in range(st.length):
|
||||
if st._sent[i].dep == 0:
|
||||
st._sent[i].l_edge = i
|
||||
st._sent[i].r_edge = i
|
||||
st._sent[i].head = 0
|
||||
st._sent[i].dep = 0
|
||||
st._sent[i].l_kids = 0
|
||||
st._sent[i].r_kids = 0
|
||||
st.fast_forward()
|
||||
Py_INCREF(st)
|
||||
return <void*>st
|
||||
|
||||
|
||||
|
@ -533,18 +537,18 @@ cdef class ArcEager(TransitionSystem):
|
|||
assert n_gold >= 1
|
||||
|
||||
def get_beam_annot(self, Beam beam):
|
||||
length = (<StateClass>beam.at(0)).c.length
|
||||
length = (<StateC*>beam.at(0)).length
|
||||
heads = [{} for _ in range(length)]
|
||||
deps = [{} for _ in range(length)]
|
||||
probs = beam.probs
|
||||
for i in range(beam.size):
|
||||
stcls = <StateClass>beam.at(i)
|
||||
self.finalize_state(stcls.c)
|
||||
if stcls.is_final():
|
||||
state = <StateC*>beam.at(i)
|
||||
self.finalize_state(state)
|
||||
if state.is_final():
|
||||
prob = probs[i]
|
||||
for j in range(stcls.c.length):
|
||||
head = j + stcls.c._sent[j].head
|
||||
dep = stcls.c._sent[j].dep
|
||||
for j in range(state.length):
|
||||
head = j + state._sent[j].head
|
||||
dep = state._sent[j].dep
|
||||
heads[j].setdefault(head, 0.0)
|
||||
heads[j][head] += prob
|
||||
deps[j].setdefault(dep, 0.0)
|
||||
|
|
|
@ -123,14 +123,14 @@ cdef class BiluoPushDown(TransitionSystem):
|
|||
entities = {}
|
||||
probs = beam.probs
|
||||
for i in range(beam.size):
|
||||
stcls = <StateClass>beam.at(i)
|
||||
if stcls.is_final():
|
||||
self.finalize_state(stcls.c)
|
||||
state = <StateC*>beam.at(i)
|
||||
if state.is_final():
|
||||
self.finalize_state(state)
|
||||
prob = probs[i]
|
||||
for j in range(stcls.c._e_i):
|
||||
start = stcls.c._ents[j].start
|
||||
end = stcls.c._ents[j].end
|
||||
label = stcls.c._ents[j].label
|
||||
for j in range(state._e_i):
|
||||
start = state._ents[j].start
|
||||
end = state._ents[j].end
|
||||
label = state._ents[j].label
|
||||
entities.setdefault((start, end, label), 0.0)
|
||||
entities[(start, end, label)] += prob
|
||||
return entities
|
||||
|
@ -139,15 +139,15 @@ cdef class BiluoPushDown(TransitionSystem):
|
|||
parses = []
|
||||
probs = beam.probs
|
||||
for i in range(beam.size):
|
||||
stcls = <StateClass>beam.at(i)
|
||||
if stcls.is_final():
|
||||
self.finalize_state(stcls.c)
|
||||
state = <StateC*>beam.at(i)
|
||||
if state.is_final():
|
||||
self.finalize_state(state)
|
||||
prob = probs[i]
|
||||
parse = []
|
||||
for j in range(stcls.c._e_i):
|
||||
start = stcls.c._ents[j].start
|
||||
end = stcls.c._ents[j].end
|
||||
label = stcls.c._ents[j].label
|
||||
for j in range(state._e_i):
|
||||
start = state._ents[j].start
|
||||
end = state._ents[j].end
|
||||
label = state._ents[j].label
|
||||
parse.append((start, end, self.strings[label]))
|
||||
parses.append((prob, parse))
|
||||
return parses
|
||||
|
|
|
@ -224,6 +224,16 @@ cdef void cpu_regression_loss(float* d_scores,
|
|||
d_scores[i] = diff
|
||||
|
||||
|
||||
def _collect_states(beams):
|
||||
cdef StateClass state
|
||||
cdef Beam beam
|
||||
states = []
|
||||
for beam in beams:
|
||||
state = StateClass.borrow(<StateC*>beam.at(0))
|
||||
states.append(state)
|
||||
return states
|
||||
|
||||
|
||||
cdef class Parser:
|
||||
"""
|
||||
Base class of the DependencyParser and EntityRecognizer.
|
||||
|
@ -336,7 +346,7 @@ cdef class Parser:
|
|||
beam_density=beam_density)
|
||||
beam = beams[0]
|
||||
output = self.moves.get_beam_annot(beam)
|
||||
state = <StateClass>beam.at(0)
|
||||
state = StateClass.borrow(<StateC*>beam.at(0))
|
||||
self.set_annotations([doc], [state], tensors=tokvecs)
|
||||
_cleanup(beam)
|
||||
return output
|
||||
|
@ -356,10 +366,10 @@ cdef class Parser:
|
|||
if beam_density is None:
|
||||
beam_density = self.cfg.get('beam_density', 0.0)
|
||||
cdef Doc doc
|
||||
cdef Beam beam
|
||||
for batch in cytoolz.partition_all(batch_size, docs):
|
||||
batch = list(batch)
|
||||
by_length = sorted(list(batch), key=lambda doc: len(doc))
|
||||
batch_in_order = list(batch)
|
||||
by_length = sorted(batch_in_order, key=lambda doc: len(doc))
|
||||
batch_beams = []
|
||||
for subbatch in cytoolz.partition_all(8, by_length):
|
||||
subbatch = list(subbatch)
|
||||
if beam_width == 1:
|
||||
|
@ -369,21 +379,20 @@ cdef class Parser:
|
|||
beams, tokvecs = self.beam_parse(subbatch,
|
||||
beam_width=beam_width,
|
||||
beam_density=beam_density)
|
||||
parse_states = []
|
||||
for beam in beams:
|
||||
parse_states.append(<StateClass>beam.at(0))
|
||||
self.set_annotations(subbatch, parse_states, tensors=tokvecs)
|
||||
yield from batch
|
||||
parse_states = _collect_states(beams)
|
||||
self.set_annotations(subbatch, parse_states, tensors=None)
|
||||
for beam in beams:
|
||||
_cleanup(beam)
|
||||
for doc in batch_in_order:
|
||||
yield doc
|
||||
|
||||
def parse_batch(self, docs):
|
||||
cdef:
|
||||
precompute_hiddens state2vec
|
||||
StateClass stcls
|
||||
Pool mem
|
||||
const float* feat_weights
|
||||
StateC* st
|
||||
StateClass stcls
|
||||
vector[StateC*] states
|
||||
int guess, nr_class, nr_feat, nr_piece, nr_dim, nr_state, nr_step
|
||||
int j
|
||||
|
@ -488,14 +497,14 @@ cdef class Parser:
|
|||
beam = Beam(nr_class, beam_width, min_density=beam_density)
|
||||
beam.initialize(self.moves.init_beam_state, doc.length, doc.c)
|
||||
for i in range(beam.width):
|
||||
stcls = <StateClass>beam.at(i)
|
||||
stcls.c.offset = offset
|
||||
state = <StateC*>beam.at(i)
|
||||
state.offset = offset
|
||||
offset += len(doc)
|
||||
beam.check_done(_check_final_state, NULL)
|
||||
while not beam.is_done:
|
||||
states = []
|
||||
for i in range(beam.size):
|
||||
stcls = <StateClass>beam.at(i)
|
||||
stcls = StateClass.borrow(<StateC*>beam.at(i))
|
||||
# This way we avoid having to score finalized states
|
||||
# We do have to take care to keep indexes aligned, though
|
||||
if not stcls.is_final():
|
||||
|
@ -511,9 +520,9 @@ cdef class Parser:
|
|||
j = 0
|
||||
c_scores = <float*>scores.data
|
||||
for i in range(beam.size):
|
||||
stcls = <StateClass>beam.at(i)
|
||||
if not stcls.is_final():
|
||||
self.moves.set_valid(beam.is_valid[i], stcls.c)
|
||||
state = <StateC*>beam.at(i)
|
||||
if not state.is_final():
|
||||
self.moves.set_valid(beam.is_valid[i], state)
|
||||
for k in range(nr_class):
|
||||
beam.scores[i][k] = c_scores[j * scores.shape[1] + k]
|
||||
j += 1
|
||||
|
@ -965,27 +974,40 @@ cdef int arg_max_if_valid(const weight_t* scores, const int* is_valid, int n) no
|
|||
|
||||
# These are passed as callbacks to thinc.search.Beam
|
||||
cdef int _transition_state(void* _dest, void* _src, class_t clas, void* _moves) except -1:
|
||||
dest = <StateClass>_dest
|
||||
src = <StateClass>_src
|
||||
dest = <StateC*>_dest
|
||||
src = <StateC*>_src
|
||||
moves = <const Transition*>_moves
|
||||
dest.clone(src)
|
||||
moves[clas].do(dest.c, moves[clas].label)
|
||||
dest.c.push_hist(clas)
|
||||
moves[clas].do(dest, moves[clas].label)
|
||||
dest.push_hist(clas)
|
||||
|
||||
|
||||
cdef int _check_final_state(void* _state, void* extra_args) except -1:
|
||||
return (<StateClass>_state).is_final()
|
||||
state = <StateC*>_state
|
||||
return state.is_final()
|
||||
|
||||
|
||||
def _cleanup(Beam beam):
|
||||
cdef StateC* state
|
||||
# Once parsing has finished, states in beam may not be unique. Is this
|
||||
# correct?
|
||||
seen = set()
|
||||
for i in range(beam.width):
|
||||
Py_XDECREF(<PyObject*>beam._states[i].content)
|
||||
Py_XDECREF(<PyObject*>beam._parents[i].content)
|
||||
addr = <size_t>beam._parents[i].content
|
||||
if addr not in seen:
|
||||
state = <StateC*>addr
|
||||
del state
|
||||
seen.add(addr)
|
||||
addr = <size_t>beam._states[i].content
|
||||
if addr not in seen:
|
||||
state = <StateC*>addr
|
||||
del state
|
||||
seen.add(addr)
|
||||
|
||||
|
||||
cdef hash_t _hash_state(void* _state, void* _) except 0:
|
||||
state = <StateClass>_state
|
||||
if state.c.is_final():
|
||||
state = <StateC*>_state
|
||||
if state.is_final():
|
||||
return 1
|
||||
else:
|
||||
return state.c.hash()
|
||||
return state.hash()
|
||||
|
|
|
@ -13,6 +13,7 @@ from ._state cimport StateC
|
|||
cdef class StateClass:
|
||||
cdef Pool mem
|
||||
cdef StateC* c
|
||||
cdef int _borrowed
|
||||
|
||||
@staticmethod
|
||||
cdef inline StateClass init(const TokenC* sent, int length):
|
||||
|
@ -20,6 +21,15 @@ cdef class StateClass:
|
|||
self.c = new StateC(sent, length)
|
||||
return self
|
||||
|
||||
@staticmethod
|
||||
cdef inline StateClass borrow(StateC* ptr):
|
||||
cdef StateClass self = StateClass()
|
||||
del self.c
|
||||
self.c = ptr
|
||||
self._borrowed = 1
|
||||
return self
|
||||
|
||||
|
||||
@staticmethod
|
||||
cdef inline StateClass init_offset(const TokenC* sent, int length, int
|
||||
offset):
|
||||
|
|
|
@ -11,11 +11,13 @@ cdef class StateClass:
|
|||
def __init__(self, Doc doc=None, int offset=0):
|
||||
cdef Pool mem = Pool()
|
||||
self.mem = mem
|
||||
self._borrowed = 0
|
||||
if doc is not None:
|
||||
self.c = new StateC(doc.c, doc.length)
|
||||
self.c.offset = offset
|
||||
|
||||
def __dealloc__(self):
|
||||
if self._borrowed != 1:
|
||||
del self.c
|
||||
|
||||
@property
|
||||
|
|
|
@ -23,8 +23,7 @@ class OracleError(Exception):
|
|||
|
||||
|
||||
cdef void* _init_state(Pool mem, int length, void* tokens) except NULL:
|
||||
cdef StateClass st = StateClass.init(<const TokenC*>tokens, length)
|
||||
Py_INCREF(st)
|
||||
cdef StateC* st = new StateC(<const TokenC*>tokens, length)
|
||||
return <void*>st
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user