mirror of
https://github.com/explosion/spaCy.git
synced 2025-03-12 23:35:47 +03:00
Try to use real histories, not oracle
This commit is contained in:
parent
c3c462e562
commit
5b2440a1fd
|
@ -32,6 +32,7 @@ cdef cppclass StateC:
|
||||||
vector[ArcC] _left_arcs
|
vector[ArcC] _left_arcs
|
||||||
vector[ArcC] _right_arcs
|
vector[ArcC] _right_arcs
|
||||||
vector[libcpp.bool] _unshiftable
|
vector[libcpp.bool] _unshiftable
|
||||||
|
vector[int] history
|
||||||
set[int] _sent_starts
|
set[int] _sent_starts
|
||||||
TokenC _empty_token
|
TokenC _empty_token
|
||||||
int length
|
int length
|
||||||
|
@ -382,3 +383,4 @@ cdef cppclass StateC:
|
||||||
this._b_i = src._b_i
|
this._b_i = src._b_i
|
||||||
this.offset = src.offset
|
this.offset = src.offset
|
||||||
this._empty_token = src._empty_token
|
this._empty_token = src._empty_token
|
||||||
|
this.history = src.history
|
||||||
|
|
|
@ -844,6 +844,7 @@ cdef class ArcEager(TransitionSystem):
|
||||||
state.print_state()
|
state.print_state()
|
||||||
)))
|
)))
|
||||||
action.do(state.c, action.label)
|
action.do(state.c, action.label)
|
||||||
|
state.c.history.push_back(i)
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
failed = False
|
failed = False
|
||||||
|
|
|
@ -20,6 +20,10 @@ cdef class StateClass:
|
||||||
if self._borrowed != 1:
|
if self._borrowed != 1:
|
||||||
del self.c
|
del self.c
|
||||||
|
|
||||||
|
@property
|
||||||
|
def history(self):
|
||||||
|
return list(self.c.history)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def stack(self):
|
def stack(self):
|
||||||
return [self.S(i) for i in range(self.c.stack_depth())]
|
return [self.S(i) for i in range(self.c.stack_depth())]
|
||||||
|
|
|
@ -67,6 +67,7 @@ cdef class TransitionSystem:
|
||||||
for clas in history:
|
for clas in history:
|
||||||
action = self.c[clas]
|
action = self.c[clas]
|
||||||
action.do(state.c, action.label)
|
action.do(state.c, action.label)
|
||||||
|
state.c.history.push_back(clas)
|
||||||
return state
|
return state
|
||||||
|
|
||||||
def get_oracle_sequence(self, Example example, _debug=False):
|
def get_oracle_sequence(self, Example example, _debug=False):
|
||||||
|
@ -110,6 +111,7 @@ cdef class TransitionSystem:
|
||||||
"S0 head?", str(state.has_head(state.S(0))),
|
"S0 head?", str(state.has_head(state.S(0))),
|
||||||
)))
|
)))
|
||||||
action.do(state.c, action.label)
|
action.do(state.c, action.label)
|
||||||
|
state.c.history.push_back(i)
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
if _debug:
|
if _debug:
|
||||||
|
@ -137,6 +139,7 @@ cdef class TransitionSystem:
|
||||||
raise ValueError(Errors.E170.format(name=name))
|
raise ValueError(Errors.E170.format(name=name))
|
||||||
action = self.lookup_transition(name)
|
action = self.lookup_transition(name)
|
||||||
action.do(state.c, action.label)
|
action.do(state.c, action.label)
|
||||||
|
state.c.history.push_back(action.clas)
|
||||||
|
|
||||||
cdef Transition lookup_transition(self, object name) except *:
|
cdef Transition lookup_transition(self, object name) except *:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
|
@ -203,15 +203,21 @@ cdef class Parser(TrainablePipe):
|
||||||
)
|
)
|
||||||
|
|
||||||
def greedy_parse(self, docs, drop=0.):
|
def greedy_parse(self, docs, drop=0.):
|
||||||
cdef vector[StateC*] states
|
|
||||||
cdef StateClass state
|
|
||||||
set_dropout_rate(self.model, drop)
|
set_dropout_rate(self.model, drop)
|
||||||
batch = self.moves.init_batch(docs)
|
|
||||||
# This is pretty dirty, but the NER can resize itself in init_batch,
|
# This is pretty dirty, but the NER can resize itself in init_batch,
|
||||||
# if labels are missing. We therefore have to check whether we need to
|
# if labels are missing. We therefore have to check whether we need to
|
||||||
# expand our model output.
|
# expand our model output.
|
||||||
self._resize()
|
self._resize()
|
||||||
model = self.model.predict(docs)
|
model = self.model.predict(docs)
|
||||||
|
batch = self.moves.init_batch(docs)
|
||||||
|
states = self._predict_states(model, batch)
|
||||||
|
model.clear_memory()
|
||||||
|
del model
|
||||||
|
return states
|
||||||
|
|
||||||
|
def _predict_states(self, model, batch):
|
||||||
|
cdef vector[StateC*] states
|
||||||
|
cdef StateClass state
|
||||||
weights = get_c_weights(model)
|
weights = get_c_weights(model)
|
||||||
for state in batch:
|
for state in batch:
|
||||||
if not state.is_final():
|
if not state.is_final():
|
||||||
|
@ -220,8 +226,6 @@ cdef class Parser(TrainablePipe):
|
||||||
with nogil:
|
with nogil:
|
||||||
self._parseC(&states[0],
|
self._parseC(&states[0],
|
||||||
weights, sizes)
|
weights, sizes)
|
||||||
model.clear_memory()
|
|
||||||
del model
|
|
||||||
return batch
|
return batch
|
||||||
|
|
||||||
def beam_parse(self, docs, int beam_width, float drop=0., beam_density=0.):
|
def beam_parse(self, docs, int beam_width, float drop=0., beam_density=0.):
|
||||||
|
@ -306,6 +310,7 @@ cdef class Parser(TrainablePipe):
|
||||||
else:
|
else:
|
||||||
action = self.moves.c[guess]
|
action = self.moves.c[guess]
|
||||||
action.do(states[i], action.label)
|
action.do(states[i], action.label)
|
||||||
|
states[i].history.push_back(guess)
|
||||||
free(is_valid)
|
free(is_valid)
|
||||||
|
|
||||||
def update(self, examples, *, drop=0., sgd=None, losses=None):
|
def update(self, examples, *, drop=0., sgd=None, losses=None):
|
||||||
|
@ -319,7 +324,7 @@ cdef class Parser(TrainablePipe):
|
||||||
# We need to take care to act on the whole batch, because we might be
|
# We need to take care to act on the whole batch, because we might be
|
||||||
# getting vectors via a listener.
|
# getting vectors via a listener.
|
||||||
n_examples = len([eg for eg in examples if self.moves.has_gold(eg)])
|
n_examples = len([eg for eg in examples if self.moves.has_gold(eg)])
|
||||||
if len(examples) == 0:
|
if n_examples == 0:
|
||||||
return losses
|
return losses
|
||||||
set_dropout_rate(self.model, drop)
|
set_dropout_rate(self.model, drop)
|
||||||
# The probability we use beam update, instead of falling back to
|
# The probability we use beam update, instead of falling back to
|
||||||
|
@ -333,7 +338,11 @@ cdef class Parser(TrainablePipe):
|
||||||
losses=losses,
|
losses=losses,
|
||||||
beam_density=self.cfg["beam_density"]
|
beam_density=self.cfg["beam_density"]
|
||||||
)
|
)
|
||||||
oracle_histories = [self.moves.get_oracle_sequence(eg) for eg in examples]
|
model, backprop_tok2vec = self.model.begin_update([eg.x for eg in examples])
|
||||||
|
final_states = self.moves.init_batch([eg.x for eg in examples])
|
||||||
|
self._predict_states(model, final_states)
|
||||||
|
histories = [list(state.history) for state in final_states]
|
||||||
|
#oracle_histories = [self.moves.get_oracle_sequence(eg) for eg in examples]
|
||||||
max_moves = self.cfg["update_with_oracle_cut_size"]
|
max_moves = self.cfg["update_with_oracle_cut_size"]
|
||||||
if max_moves >= 1:
|
if max_moves >= 1:
|
||||||
# Chop sequences into lengths of this many words, to make the
|
# Chop sequences into lengths of this many words, to make the
|
||||||
|
@ -341,15 +350,13 @@ cdef class Parser(TrainablePipe):
|
||||||
max_moves = int(random.uniform(max_moves // 2, max_moves * 2))
|
max_moves = int(random.uniform(max_moves // 2, max_moves * 2))
|
||||||
states, golds, _ = self._init_gold_batch(
|
states, golds, _ = self._init_gold_batch(
|
||||||
examples,
|
examples,
|
||||||
oracle_histories,
|
histories,
|
||||||
max_length=max_moves
|
max_length=max_moves
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
states, golds, _ = self.moves.init_gold_batch(examples)
|
states, golds, _ = self.moves.init_gold_batch(examples)
|
||||||
if not states:
|
if not states:
|
||||||
return losses
|
return losses
|
||||||
docs = [eg.predicted for eg in examples]
|
|
||||||
model, backprop_tok2vec = self.model.begin_update(docs)
|
|
||||||
|
|
||||||
all_states = list(states)
|
all_states = list(states)
|
||||||
states_golds = list(zip(states, golds))
|
states_golds = list(zip(states, golds))
|
||||||
|
@ -373,15 +380,7 @@ cdef class Parser(TrainablePipe):
|
||||||
backprop_tok2vec(golds)
|
backprop_tok2vec(golds)
|
||||||
if sgd not in (None, False):
|
if sgd not in (None, False):
|
||||||
self.finish_update(sgd)
|
self.finish_update(sgd)
|
||||||
# If we want to set the annotations based on predictions, it's really
|
self.set_annotations([eg.x for eg in examples], final_states)
|
||||||
# hard to avoid parsing the data twice :(.
|
|
||||||
# The issue is that we cut up the gold batch into sub-states, and that
|
|
||||||
# means there's no one predicted sequence during the update.
|
|
||||||
gold_states = [
|
|
||||||
self.moves.follow_history(doc, history)
|
|
||||||
for doc, history in zip(docs, oracle_histories)
|
|
||||||
]
|
|
||||||
self.set_annotations(docs, gold_states)
|
|
||||||
# Ugh, this is annoying. If we're working on GPU, we want to free the
|
# Ugh, this is annoying. If we're working on GPU, we want to free the
|
||||||
# memory ASAP. It seems that Python doesn't necessarily get around to
|
# memory ASAP. It seems that Python doesn't necessarily get around to
|
||||||
# removing these in time if we don't explicitly delete? It's confusing.
|
# removing these in time if we don't explicitly delete? It's confusing.
|
||||||
|
@ -599,6 +598,7 @@ cdef class Parser(TrainablePipe):
|
||||||
StateClass state
|
StateClass state
|
||||||
Transition action
|
Transition action
|
||||||
all_states = self.moves.init_batch([eg.predicted for eg in examples])
|
all_states = self.moves.init_batch([eg.predicted for eg in examples])
|
||||||
|
assert len(all_states) == len(examples) == len(oracle_histories)
|
||||||
states = []
|
states = []
|
||||||
golds = []
|
golds = []
|
||||||
for state, eg, history in zip(all_states, examples, oracle_histories):
|
for state, eg, history in zip(all_states, examples, oracle_histories):
|
||||||
|
@ -616,6 +616,7 @@ cdef class Parser(TrainablePipe):
|
||||||
for clas in history[i:i+max_length]:
|
for clas in history[i:i+max_length]:
|
||||||
action = self.moves.c[clas]
|
action = self.moves.c[clas]
|
||||||
action.do(state.c, action.label)
|
action.do(state.c, action.label)
|
||||||
|
state.c.history.push_back(clas)
|
||||||
if state.is_final():
|
if state.is_final():
|
||||||
break
|
break
|
||||||
if self.moves.has_gold(eg, start_state.B(0), state.B(0)):
|
if self.moves.has_gold(eg, start_state.B(0), state.B(0)):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user