Try to use real histories, not oracle

Matthew Honnibal 2021-01-25 18:59:52 +11:00
parent c3c462e562
commit 5b2440a1fd
5 changed files with 30 additions and 19 deletions
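In outline: every parser state now carries a log of the transitions applied to it, and update() cuts the training batch along these real, predicted histories instead of oracle sequences. A minimal Python model of the mechanism (a toy stand-in for the Cython StateC, not spaCy's API):

class State:
    # Toy stand-in for StateC: a state that logs its own transitions.
    def __init__(self):
        self.history = []              # transition ids, in applied order

    def apply(self, clas):
        # ... mutate stack/buffer/arcs for transition `clas` here ...
        self.history.append(clas)      # the record this commit adds

    def clone(self):
        new = State()
        new.history = list(self.history)   # copies must keep the log
        return new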

View File

@@ -32,6 +32,7 @@ cdef cppclass StateC:
     vector[ArcC] _left_arcs
     vector[ArcC] _right_arcs
     vector[libcpp.bool] _unshiftable
+    vector[int] history
     set[int] _sent_starts
     TokenC _empty_token
     int length
@@ -382,3 +383,4 @@ cdef cppclass StateC:
         this._b_i = src._b_i
         this.offset = src.offset
         this._empty_token = src._empty_token
+        this.history = src.history
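The clone change matters because states get copied mid-parse (the gold batch is cut into sub-states, and beam parsing copies states freely); a clone that dropped the history would leave a replay log that lies about how the state was reached. Using the toy State above:

a = State()
a.apply(3)
b = a.clone()
b.apply(7)
assert a.history == [3]        # the original is untouched
assert b.history == [3, 7]     # the clone continues its own log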

View File

@@ -844,6 +844,7 @@ cdef class ArcEager(TransitionSystem):
                         state.print_state()
                     )))
                 action.do(state.c, action.label)
+                state.c.history.push_back(i)
                 break
             else:
                 failed = False

View File

@@ -20,6 +20,10 @@ cdef class StateClass:
         if self._borrowed != 1:
             del self.c

+    @property
+    def history(self):
+        return list(self.c.history)
+
     @property
     def stack(self):
         return [self.S(i) for i in range(self.c.stack_depth())]
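On the Python side the history comes back as a copy of the C++ vector, so callers get a snapshot rather than a live view. A self-contained sketch of that behaviour:

class S:                         # toy stand-in, as before
    def __init__(self):
        self.history = []

s = S()
s.history.append(1)
snapshot = list(s.history)       # what the new property returns
s.history.append(2)
assert snapshot == [1]           # a copy, not a live view
assert s.history == [1, 2]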

View File

@@ -67,6 +67,7 @@ cdef class TransitionSystem:
         for clas in history:
             action = self.c[clas]
             action.do(state.c, action.label)
+            state.c.history.push_back(clas)
         return state

     def get_oracle_sequence(self, Example example, _debug=False):
@@ -110,6 +111,7 @@ cdef class TransitionSystem:
                     "S0 head?", str(state.has_head(state.S(0))),
                 )))
             action.do(state.c, action.label)
+            state.c.history.push_back(i)
             break
         else:
             if _debug:
@@ -137,6 +139,7 @@ cdef class TransitionSystem:
             raise ValueError(Errors.E170.format(name=name))
         action = self.lookup_transition(name)
         action.do(state.c, action.label)
+        state.c.history.push_back(action.clas)

     cdef Transition lookup_transition(self, object name) except *:
         raise NotImplementedError
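All three call sites in this file follow the same discipline, as does the ArcEager change above: any code path that calls action.do() immediately pushes the transition's id onto the state's history. A hypothetical helper (not in the diff) makes the invariant explicit:

def do_and_record(moves, state, clas):
    # Applying a transition and recording it always happen together,
    # so state.c.history is a faithful replay log for the state.
    action = moves.c[clas]
    action.do(state.c, action.label)
    state.c.history.push_back(clas)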

View File

@@ -203,15 +203,21 @@ cdef class Parser(TrainablePipe):
         )

     def greedy_parse(self, docs, drop=0.):
-        cdef vector[StateC*] states
-        cdef StateClass state
         set_dropout_rate(self.model, drop)
-        batch = self.moves.init_batch(docs)
         # This is pretty dirty, but the NER can resize itself in init_batch,
         # if labels are missing. We therefore have to check whether we need to
         # expand our model output.
         self._resize()
         model = self.model.predict(docs)
+        batch = self.moves.init_batch(docs)
+        states = self._predict_states(model, batch)
+        model.clear_memory()
+        del model
+        return states
+
+    def _predict_states(self, model, batch):
+        cdef vector[StateC*] states
+        cdef StateClass state
         weights = get_c_weights(model)
         for state in batch:
             if not state.is_final():
@@ -220,8 +226,6 @@ cdef class Parser(TrainablePipe):
         with nogil:
             self._parseC(&states[0],
                 weights, sizes)
-        model.clear_memory()
-        del model
         return batch

     def beam_parse(self, docs, int beam_width, float drop=0., beam_density=0.):
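The refactor splits greedy_parse into model setup/teardown and the state-transition loop, so update() can run the same loop over states it builds itself. Roughly, as hypothetical driver code with simplified names:

# Inference path: predict-mode model, then the shared loop.
states = parser.greedy_parse(docs)

# Training path (see the update() changes below): a begin_update model
# handle, the same moves.init_batch() states, the same inner loop.
model, backprop = parser.model.begin_update(docs)
final_states = parser.moves.init_batch(docs)
parser._predict_states(model, final_states)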
@@ -306,6 +310,7 @@ cdef class Parser(TrainablePipe):
             else:
                 action = self.moves.c[guess]
                 action.do(states[i], action.label)
+                states[i].history.push_back(guess)
         free(is_valid)

     def update(self, examples, *, drop=0., sgd=None, losses=None):
@@ -319,7 +324,7 @@ cdef class Parser(TrainablePipe):
         # We need to take care to act on the whole batch, because we might be
         # getting vectors via a listener.
         n_examples = len([eg for eg in examples if self.moves.has_gold(eg)])
-        if len(examples) == 0:
+        if n_examples == 0:
             return losses
         set_dropout_rate(self.model, drop)
         # The probability we use beam update, instead of falling back to
@@ -333,7 +338,11 @@ cdef class Parser(TrainablePipe):
                 losses=losses,
                 beam_density=self.cfg["beam_density"]
             )
-        oracle_histories = [self.moves.get_oracle_sequence(eg) for eg in examples]
+        model, backprop_tok2vec = self.model.begin_update([eg.x for eg in examples])
+        final_states = self.moves.init_batch([eg.x for eg in examples])
+        self._predict_states(model, final_states)
+        histories = [list(state.history) for state in final_states]
+        #oracle_histories = [self.moves.get_oracle_sequence(eg) for eg in examples]
         max_moves = self.cfg["update_with_oracle_cut_size"]
         if max_moves >= 1:
             # Chop sequences into lengths of this many words, to make the
@@ -341,15 +350,13 @@ cdef class Parser(TrainablePipe):
             max_moves = int(random.uniform(max_moves // 2, max_moves * 2))
             states, golds, _ = self._init_gold_batch(
                 examples,
-                oracle_histories,
+                histories,
                 max_length=max_moves
             )
         else:
             states, golds, _ = self.moves.init_gold_batch(examples)
         if not states:
             return losses
-        docs = [eg.predicted for eg in examples]
-        model, backprop_tok2vec = self.model.begin_update(docs)

         all_states = list(states)
         states_golds = list(zip(states, golds))
@@ -373,15 +380,7 @@ cdef class Parser(TrainablePipe):
             backprop_tok2vec(golds)
             if sgd not in (None, False):
                 self.finish_update(sgd)
-        # If we want to set the annotations based on predictions, it's really
-        # hard to avoid parsing the data twice :(.
-        # The issue is that we cut up the gold batch into sub-states, and that
-        # means there's no one predicted sequence during the update.
-        gold_states = [
-            self.moves.follow_history(doc, history)
-            for doc, history in zip(docs, oracle_histories)
-        ]
-        self.set_annotations(docs, gold_states)
+        self.set_annotations([eg.x for eg in examples], final_states)
         # Ugh, this is annoying. If we're working on GPU, we want to free the
         # memory ASAP. It seems that Python doesn't necessarily get around to
         # removing these in time if we don't explicitly delete? It's confusing.
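Putting the update() changes together: the batch is parsed once with the training-mode model, the predicted histories are read off the final states, those histories drive the sequence chopping, and the same final states are reused for set_annotations, so the data is no longer parsed twice. A condensed sketch (simplified; the max_length value is illustrative):

def update_sketch(parser, examples):
    docs = [eg.x for eg in examples]
    model, backprop_tok2vec = parser.model.begin_update(docs)
    final_states = parser.moves.init_batch(docs)
    parser._predict_states(model, final_states)
    histories = [list(state.history) for state in final_states]
    # Chop along the model's own predictions, not the oracle sequence.
    states, golds, _ = parser._init_gold_batch(examples, histories,
                                               max_length=64)
    # ... gradient steps over (states, golds) elided ...
    parser.set_annotations(docs, final_states)  # reuse; no second parse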
@@ -599,6 +598,7 @@ cdef class Parser(TrainablePipe):
             StateClass state
             Transition action
         all_states = self.moves.init_batch([eg.predicted for eg in examples])
+        assert len(all_states) == len(examples) == len(oracle_histories)
         states = []
         golds = []
         for state, eg, history in zip(all_states, examples, oracle_histories):
@@ -616,6 +616,7 @@ cdef class Parser(TrainablePipe):
                 for clas in history[i:i+max_length]:
                     action = self.moves.c[clas]
                     action.do(state.c, action.label)
+                    state.c.history.push_back(clas)
                 if state.is_final():
                     break
             if self.moves.has_gold(eg, start_state.B(0), state.B(0)):
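_init_gold_batch replays whichever histories it is given (after this commit, the predicted ones) and records them on the sub-states as it goes. The cutting scheme itself, as a standalone sketch:

def cut_history(history, max_length):
    # Split one transition history into consecutive chunks of at most
    # max_length moves; each chunk becomes a training sub-sequence whose
    # start state is reached by replaying the moves before it.
    return [history[i:i + max_length]
            for i in range(0, len(history), max_length)]

assert cut_history([1, 2, 3, 4, 5], 2) == [[1, 2], [3, 4], [5]]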