mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 17:36:30 +03:00
Fixes for beam parsing. Not working
This commit is contained in:
parent
c96d769836
commit
28e930aae0
|
@ -60,10 +60,16 @@ cdef class ParserBeam(object):
|
|||
st = <StateClass>beam.at(i)
|
||||
st.c.offset = state.c.offset
|
||||
self.beams.append(beam)
|
||||
|
||||
def __dealloc__(self):
|
||||
if self.beams is not None:
|
||||
for beam in self.beams:
|
||||
if beam is not None:
|
||||
_cleanup(beam)
|
||||
|
||||
@property
|
||||
def is_done(self):
|
||||
return all(beam.is_done for beam in self.beams)
|
||||
return all(b.is_done for b in self.beams)
|
||||
|
||||
def __getitem__(self, i):
|
||||
return self.beams[i]
|
||||
|
@ -77,28 +83,31 @@ cdef class ParserBeam(object):
|
|||
self._set_scores(beam, scores[i])
|
||||
if self.golds is not None:
|
||||
self._set_costs(beam, self.golds[i], follow_gold=follow_gold)
|
||||
if follow_gold:
|
||||
assert self.golds is not None
|
||||
beam.advance(_transition_state, NULL, <void*>self.moves.c)
|
||||
else:
|
||||
beam.advance(_transition_state, _hash_state, <void*>self.moves.c)
|
||||
beam.check_done(_check_final_state, NULL)
|
||||
if follow_gold:
|
||||
assert self.golds is not None
|
||||
beam.advance(_transition_state, NULL, <void*>self.moves.c)
|
||||
else:
|
||||
beam.advance(_transition_state, _hash_state, <void*>self.moves.c)
|
||||
beam.check_done(_check_final_state, NULL)
|
||||
|
||||
def _set_scores(self, Beam beam, scores):
|
||||
def _set_scores(self, Beam beam, float[:, ::1] scores):
|
||||
cdef float* c_scores = &scores[0, 0]
|
||||
for i in range(beam.size):
|
||||
state = <StateClass>beam.at(i)
|
||||
if not state.is_final():
|
||||
for j in range(beam.nr_class):
|
||||
beam.scores[i][j] = scores[i, j]
|
||||
self.moves.set_valid(beam.is_valid[i], state.c)
|
||||
beam.scores[i][j] = c_scores[i * beam.nr_class + j]
|
||||
self.moves.set_valid(beam.is_valid[i], state.c)
|
||||
|
||||
def _set_costs(self, Beam beam, GoldParse gold, int follow_gold=False):
|
||||
for i in range(beam.size):
|
||||
state = <StateClass>beam.at(i)
|
||||
self.moves.set_costs(beam.is_valid[i], beam.costs[i], state, gold)
|
||||
if follow_gold:
|
||||
for j in range(beam.nr_class):
|
||||
beam.is_valid[i][j] *= beam.costs[i][j] <= 0
|
||||
if not state.c.is_final():
|
||||
self.moves.set_costs(beam.is_valid[i], beam.costs[i], state, gold)
|
||||
if follow_gold:
|
||||
for j in range(beam.nr_class):
|
||||
if beam.costs[i][j] >= 1:
|
||||
beam.is_valid[i][j] = 0
|
||||
|
||||
|
||||
def get_token_ids(states, int n_tokens):
|
||||
|
@ -122,7 +131,7 @@ def update_beam(TransitionSystem moves, int nr_feature, int max_steps,
|
|||
pbeam = ParserBeam(moves, states, golds,
|
||||
width=width, density=density)
|
||||
gbeam = ParserBeam(moves, states, golds,
|
||||
width=width, density=density)
|
||||
width=width, density=0.0)
|
||||
beam_maps = []
|
||||
backprops = []
|
||||
violns = [MaxViolation() for _ in range(len(states))]
|
||||
|
@ -145,7 +154,7 @@ def update_beam(TransitionSystem moves, int nr_feature, int max_steps,
|
|||
|
||||
for i, violn in enumerate(violns):
|
||||
violn.check_crf(pbeam[i], gbeam[i])
|
||||
|
||||
|
||||
histories = [(v.p_hist + v.g_hist) for v in violns]
|
||||
losses = [(v.p_probs + v.g_probs) for v in violns]
|
||||
states_d_scores = get_gradient(moves.n_moves, beam_maps,
|
||||
|
|
|
@ -66,6 +66,7 @@ from ..attrs cimport ID, TAG, DEP, ORTH, NORM, PREFIX, SUFFIX, TAG
|
|||
from . import _beam_utils
|
||||
|
||||
USE_FINE_TUNE = True
|
||||
BEAM_PARSE = True
|
||||
|
||||
def get_templates(*args, **kwargs):
|
||||
return []
|
||||
|
@ -335,7 +336,7 @@ cdef class Parser:
|
|||
return output
|
||||
|
||||
def pipe(self, docs, int batch_size=1000, int n_threads=2,
|
||||
beam_width=1, beam_density=0.001):
|
||||
beam_width=4, beam_density=0.001):
|
||||
"""
|
||||
Process a stream of documents.
|
||||
|
||||
|
@ -348,14 +349,18 @@ cdef class Parser:
|
|||
Yields (Doc): Documents, in order.
|
||||
"""
|
||||
cdef Doc doc
|
||||
cdef Beam beam
|
||||
for docs in cytoolz.partition_all(batch_size, docs):
|
||||
docs = list(docs)
|
||||
tokvecs = [doc.tensor for doc in docs]
|
||||
if beam_width == 1:
|
||||
parse_states = self.parse_batch(docs, tokvecs)
|
||||
else:
|
||||
parse_states = self.beam_parse(docs, tokvecs,
|
||||
beam_width=beam_width, beam_density=beam_density)
|
||||
beams = self.beam_parse(docs, tokvecs,
|
||||
beam_width=beam_width, beam_density=beam_density)
|
||||
parse_states = []
|
||||
for beam in beams:
|
||||
parse_states.append(<StateClass>beam.at(0))
|
||||
self.set_annotations(docs, parse_states)
|
||||
yield from docs
|
||||
|
||||
|
@ -462,6 +467,9 @@ cdef class Parser:
|
|||
return beams
|
||||
|
||||
def update(self, docs_tokvecs, golds, drop=0., sgd=None, losses=None):
|
||||
if BEAM_PARSE:
|
||||
return self.update_beam(docs_tokvecs, golds, drop=drop, sgd=sgd,
|
||||
losses=losses)
|
||||
if losses is not None and self.name not in losses:
|
||||
losses[self.name] = 0.
|
||||
docs, tokvec_lists = docs_tokvecs
|
||||
|
@ -528,9 +536,16 @@ cdef class Parser:
|
|||
return d_tokvecs
|
||||
|
||||
def update_beam(self, docs_tokvecs, golds, drop=0., sgd=None, losses=None):
|
||||
if losses is not None and self.name not in losses:
|
||||
losses[self.name] = 0.
|
||||
docs, tokvecs = docs_tokvecs
|
||||
lengths = [len(d) for d in docs]
|
||||
tokvecs = self.model[0].ops.flatten(tokvecs)
|
||||
if USE_FINE_TUNE:
|
||||
my_tokvecs, bp_my_tokvecs = self.model[0].begin_update(docs_tokvecs, drop=drop)
|
||||
my_tokvecs = self.model[0].ops.flatten(my_tokvecs)
|
||||
tokvecs += my_tokvecs
|
||||
|
||||
states, golds, max_moves = self._init_gold_batch(docs, golds)
|
||||
|
||||
cuda_stream = get_cuda_stream()
|
||||
|
@ -554,8 +569,10 @@ cdef class Parser:
|
|||
backprop_lower.append((ids, d_vector, bp_vectors))
|
||||
d_tokvecs = self.model[0].ops.allocate(tokvecs.shape)
|
||||
self._make_updates(d_tokvecs, backprop_lower, sgd, cuda_stream)
|
||||
lengths = [len(doc) for doc in docs]
|
||||
return self.model[0].ops.unflatten(d_tokvecs, lengths)
|
||||
d_tokvecs = self.model[0].ops.unflatten(d_tokvecs, lengths)
|
||||
if USE_FINE_TUNE:
|
||||
bp_my_tokvecs(d_tokvecs, sgd=sgd)
|
||||
return d_tokvecs
|
||||
|
||||
def _init_gold_batch(self, whole_docs, whole_golds):
|
||||
"""Make a square batch, of length equal to the shortest doc. A long
|
||||
|
|
Loading…
Reference in New Issue
Block a user