Mirror of https://github.com/explosion/spaCy.git (last synced 2025-02-28 17:40:39 +03:00)
Restore beam_density argument for parser beam
This commit is contained in:
parent 625ee6c464
commit 7b9195657b
|
@ -68,7 +68,7 @@ cdef class ParserBeam(object):
|
||||||
cdef StateClass state
|
cdef StateClass state
|
||||||
cdef StateC* st
|
cdef StateC* st
|
||||||
for state in states:
|
for state in states:
|
||||||
beam = Beam(self.moves.n_moves, width, density)
|
beam = Beam(self.moves.n_moves, width, min_density=density)
|
||||||
beam.initialize(self.moves.init_beam_state, state.c.length,
|
beam.initialize(self.moves.init_beam_state, state.c.length,
|
||||||
state.c._sent)
|
state.c._sent)
|
||||||
for i in range(beam.width):
|
for i in range(beam.width):
|
||||||
|
@ -161,12 +161,12 @@ def update_beam(TransitionSystem moves, int nr_feature, int max_steps,
|
||||||
states, golds,
|
states, golds,
|
||||||
state2vec, vec2scores,
|
state2vec, vec2scores,
|
||||||
int width, losses=None, drop=0.,
|
int width, losses=None, drop=0.,
|
||||||
early_update=True):
|
early_update=True, beam_density=0.0):
|
||||||
global nr_update
|
global nr_update
|
||||||
cdef MaxViolation violn
|
cdef MaxViolation violn
|
||||||
nr_update += 1
|
nr_update += 1
|
||||||
pbeam = ParserBeam(moves, states, golds, width=width)
|
pbeam = ParserBeam(moves, states, golds, width=width, density=beam_density)
|
||||||
gbeam = ParserBeam(moves, states, golds, width=width)
|
gbeam = ParserBeam(moves, states, golds, width=width, density=beam_density)
|
||||||
cdef StateClass state
|
cdef StateClass state
|
||||||
beam_maps = []
|
beam_maps = []
|
||||||
backprops = []
|
backprops = []
|
||||||
|
|
|
@ -182,7 +182,9 @@ cdef class Parser:
|
||||||
"""
|
"""
|
||||||
if beam_width is None:
|
if beam_width is None:
|
||||||
beam_width = self.cfg.get('beam_width', 1)
|
beam_width = self.cfg.get('beam_width', 1)
|
||||||
states = self.predict([doc], beam_width=beam_width)
|
beam_density = self.cfg.get('beam_density', 0.)
|
||||||
|
states = self.predict([doc], beam_width=beam_width,
|
||||||
|
beam_density=beam_density)
|
||||||
self.set_annotations([doc], states, tensors=None)
|
self.set_annotations([doc], states, tensors=None)
|
||||||
return doc
|
return doc
|
||||||
|
|
||||||
|
@ -197,24 +199,27 @@ cdef class Parser:
|
||||||
"""
|
"""
|
||||||
if beam_width is None:
|
if beam_width is None:
|
||||||
beam_width = self.cfg.get('beam_width', 1)
|
beam_width = self.cfg.get('beam_width', 1)
|
||||||
|
beam_density = self.cfg.get('beam_density', 0.)
|
||||||
cdef Doc doc
|
cdef Doc doc
|
||||||
for batch in cytoolz.partition_all(batch_size, docs):
|
for batch in cytoolz.partition_all(batch_size, docs):
|
||||||
batch_in_order = list(batch)
|
batch_in_order = list(batch)
|
||||||
by_length = sorted(batch_in_order, key=lambda doc: len(doc))
|
by_length = sorted(batch_in_order, key=lambda doc: len(doc))
|
||||||
for subbatch in cytoolz.partition_all(8, by_length):
|
for subbatch in cytoolz.partition_all(8, by_length):
|
||||||
subbatch = list(subbatch)
|
subbatch = list(subbatch)
|
||||||
parse_states = self.predict(subbatch, beam_width=beam_width)
|
parse_states = self.predict(subbatch, beam_width=beam_width,
|
||||||
|
beam_density=beam_density)
|
||||||
self.set_annotations(subbatch, parse_states, tensors=None)
|
self.set_annotations(subbatch, parse_states, tensors=None)
|
||||||
for doc in batch_in_order:
|
for doc in batch_in_order:
|
||||||
yield doc
|
yield doc
|
||||||
|
|
||||||
def predict(self, docs, beam_width=1, drop=0.):
|
def predict(self, docs, beam_width=1, beam_density=0.0, drop=0.):
|
||||||
if isinstance(docs, Doc):
|
if isinstance(docs, Doc):
|
||||||
docs = [docs]
|
docs = [docs]
|
||||||
if beam_width < 2:
|
if beam_width < 2:
|
||||||
return self.greedy_parse(docs, drop=drop)
|
return self.greedy_parse(docs, drop=drop)
|
||||||
else:
|
else:
|
||||||
return self.beam_parse(docs, beam_width=beam_width, drop=drop)
|
return self.beam_parse(docs, beam_width=beam_width,
|
||||||
|
beam_density=beam_density, drop=drop)
|
||||||
|
|
||||||
def greedy_parse(self, docs, drop=0.):
|
def greedy_parse(self, docs, drop=0.):
|
||||||
cdef vector[StateC*] states
|
cdef vector[StateC*] states
|
||||||
|
@ -231,12 +236,12 @@ cdef class Parser:
|
||||||
weights, sizes)
|
weights, sizes)
|
||||||
return batch
|
return batch
|
||||||
|
|
||||||
def beam_parse(self, docs, int beam_width, float drop=0.):
|
def beam_parse(self, docs, int beam_width, float drop=0., beam_density=0.):
|
||||||
cdef Beam beam
|
cdef Beam beam
|
||||||
cdef Doc doc
|
cdef Doc doc
|
||||||
cdef np.ndarray token_ids
|
cdef np.ndarray token_ids
|
||||||
model = self.model(docs)
|
model = self.model(docs)
|
||||||
beams = self.moves.init_beams(docs, beam_width)
|
beams = self.moves.init_beams(docs, beam_width, beam_density=beam_density)
|
||||||
token_ids = numpy.zeros((len(docs) * beam_width, self.nr_feature),
|
token_ids = numpy.zeros((len(docs) * beam_width, self.nr_feature),
|
||||||
dtype='i', order='C')
|
dtype='i', order='C')
|
||||||
cdef int* c_ids
|
cdef int* c_ids
|
||||||
|
@ -358,9 +363,9 @@ cdef class Parser:
|
||||||
# a greedy update
|
# a greedy update
|
||||||
beam_update_prob = self.cfg.get('beam_update_prob', 1.0)
|
beam_update_prob = self.cfg.get('beam_update_prob', 1.0)
|
||||||
if self.cfg.get('beam_width', 1) >= 2 and numpy.random.random() < beam_update_prob:
|
if self.cfg.get('beam_width', 1) >= 2 and numpy.random.random() < beam_update_prob:
|
||||||
return self.update_beam(docs, golds,
|
return self.update_beam(docs, golds, self.cfg.get('beam_width', 1),
|
||||||
self.cfg['beam_width'],
|
drop=drop, sgd=sgd, losses=losses,
|
||||||
drop=drop, sgd=sgd, losses=losses)
|
beam_density=self.cfg.get('beam_density', 0.0))
|
||||||
# Chop sequences into lengths of this many transitions, to make the
|
# Chop sequences into lengths of this many transitions, to make the
|
||||||
# batch uniform length.
|
# batch uniform length.
|
||||||
cut_gold = numpy.random.choice(range(20, 100))
|
cut_gold = numpy.random.choice(range(20, 100))
|
||||||
|
@ -384,7 +389,8 @@ cdef class Parser:
|
||||||
finish_update(golds, sgd=sgd)
|
finish_update(golds, sgd=sgd)
|
||||||
return losses
|
return losses
|
||||||
|
|
||||||
def update_beam(self, docs, golds, width, drop=0., sgd=None, losses=None):
|
def update_beam(self, docs, golds, width, drop=0., sgd=None, losses=None,
|
||||||
|
beam_density=0.0):
|
||||||
lengths = [len(d) for d in docs]
|
lengths = [len(d) for d in docs]
|
||||||
states = self.moves.init_batch(docs)
|
states = self.moves.init_batch(docs)
|
||||||
for gold in golds:
|
for gold in golds:
|
||||||
|
@ -392,7 +398,8 @@ cdef class Parser:
|
||||||
model, finish_update = self.model.begin_update(docs, drop=drop)
|
model, finish_update = self.model.begin_update(docs, drop=drop)
|
||||||
states_d_scores, backprops, beams = _beam_utils.update_beam(
|
states_d_scores, backprops, beams = _beam_utils.update_beam(
|
||||||
self.moves, self.nr_feature, 10000, states, golds, model.state2vec,
|
self.moves, self.nr_feature, 10000, states, golds, model.state2vec,
|
||||||
model.vec2scores, width, drop=drop, losses=losses)
|
model.vec2scores, width, drop=drop, losses=losses,
|
||||||
|
beam_density=beam_density)
|
||||||
for i, d_scores in enumerate(states_d_scores):
|
for i, d_scores in enumerate(states_d_scores):
|
||||||
losses[self.name] += (d_scores**2).sum()
|
losses[self.name] += (d_scores**2).sum()
|
||||||
ids, bp_vectors, bp_scores = backprops[i]
|
ids, bp_vectors, bp_scores = backprops[i]
|
||||||
|
|
|
@ -60,12 +60,12 @@ cdef class TransitionSystem:
|
||||||
offset += len(doc)
|
offset += len(doc)
|
||||||
return states
|
return states
|
||||||
|
|
||||||
def init_beams(self, docs, beam_width):
|
def init_beams(self, docs, beam_width, beam_density=0.):
|
||||||
cdef Doc doc
|
cdef Doc doc
|
||||||
beams = []
|
beams = []
|
||||||
cdef int offset = 0
|
cdef int offset = 0
|
||||||
for doc in docs:
|
for doc in docs:
|
||||||
beam = Beam(self.n_moves, beam_width)
|
beam = Beam(self.n_moves, beam_width, min_density=beam_density)
|
||||||
beam.initialize(self.init_beam_state, doc.length, doc.c)
|
beam.initialize(self.init_beam_state, doc.length, doc.c)
|
||||||
for i in range(beam.width):
|
for i in range(beam.width):
|
||||||
state = <StateC*>beam.at(i)
|
state = <StateC*>beam.at(i)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user