Restore the 'cutting' in parser training

This commit is contained in:
Matthew Honnibal 2020-06-25 18:52:42 +02:00
parent 6bda23ad26
commit 403b362a5d
3 changed files with 75 additions and 18 deletions

View File

@ -576,15 +576,20 @@ cdef class ArcEager(TransitionSystem):
def is_gold_parse(self, StateClass state, gold):
raise NotImplementedError
def init_gold(self, StateClass state, Example example):
gold = ArcEagerGold(self, state, example)
self._replace_unseen_labels(gold)
return gold
def init_gold_batch(self, examples):
states = self.init_batch([eg.predicted for eg in examples])
keeps = [i for i, (eg, s) in enumerate(zip(examples, states))
if self.has_gold(eg) and not s.is_final()]
golds = [ArcEagerGold(self, states[i], examples[i]) for i in keeps]
states = [states[i] for i in keeps]
for gold in golds:
self._replace_unseen_labels(gold)
n_steps = sum([len(s.queue) * 4 for s in states])
all_states = self.init_batch([eg.predicted for eg in examples])
golds = []
states = []
for state, eg in zip(all_states, examples):
if self.has_gold(eg) and not state.is_final():
golds.append(self.init_gold(state, eg))
states.append(state)
n_steps = sum([len(s.queue) for s in states])
return states, golds, n_steps
def _replace_unseen_labels(self, ArcEagerGold gold):
@ -684,8 +689,12 @@ cdef class ArcEager(TransitionSystem):
doc.is_parsed = True
set_children_from_heads(doc.c, doc.length)
def has_gold(self, Example eg):
return eg.y.is_parsed
def has_gold(self, Example eg, start=0, end=None):
for word in eg.y[start:end]:
if word.dep != 0:
return True
else:
return False
cdef int set_valid(self, int* output, const StateC* st) nogil:
cdef bint[N_MOVES] is_valid

View File

@ -130,11 +130,13 @@ cdef class BiluoPushDown(TransitionSystem):
return MOVE_NAMES[move] + '-' + self.strings[label]
def init_gold_batch(self, examples):
states = self.init_batch([eg.predicted for eg in examples])
keeps = [i for i, (s, eg) in enumerate(zip(states, examples))
if not s.is_final() and self.has_gold(eg)]
golds = [BiluoGold(self, states[i], examples[i]) for i in keeps]
states = [states[i] for i in keeps]
all_states = self.init_batch([eg.predicted for eg in examples])
golds = []
states = []
for state, eg in zip(all_states, examples):
if self.has_gold(eg) and not state.is_final():
golds.append(self.init_gold(state, eg))
states.append(state)
n_steps = sum([len(s.queue) for s in states])
return states, golds, n_steps
@ -237,8 +239,15 @@ cdef class BiluoPushDown(TransitionSystem):
self.add_action(UNIT, st._sent[i].ent_type)
self.add_action(LAST, st._sent[i].ent_type)
def has_gold(self, Example eg):
return eg.y.is_nered
def init_gold(self, StateClass state, Example example):
return BiluoGold(self, state, example)
def has_gold(self, Example eg, start=0, end=None):
for word in eg.y[start:end]:
if word.ent_iob != 0:
return True
else:
return False
def get_cost(self, StateClass stcls, gold, int i):
if not isinstance(gold, BiluoGold):

View File

@ -272,7 +272,7 @@ cdef class Parser:
# Prepare the stepwise model, and get the callback for finishing the batch
model, backprop_tok2vec = self.model.begin_update(
[eg.predicted for eg in examples])
states, golds, max_steps = self.moves.init_gold_batch(examples)
states, golds, max_steps = self._init_gold_batch(examples)
all_states = list(states)
states_golds = zip(states, golds)
for _ in range(max_steps):
@ -490,3 +490,42 @@ cdef class Parser:
except AttributeError:
raise ValueError(Errors.E149)
return self
def _init_gold_batch(self, examples, min_length=5, max_length=500):
"""Make a square batch, of length equal to the shortest doc. A long
doc will get multiple states. Let's say we have a doc of length 2*N,
where N is the shortest doc. We'll make two states, one representing
long_doc[:N], and another representing long_doc[N:]."""
cdef:
StateClass state
Transition action
all_states = self.moves.init_batch([eg.predicted for eg in examples])
kept = []
for state, eg in zip(all_states, examples):
if self.moves.has_gold(eg) and not state.is_final():
gold = self.moves.init_gold(state, eg)
kept.append((eg, state, gold))
max_length = max(min_length, min(max_length, min([len(eg.x) for eg in examples])))
max_moves = 0
states = []
golds = []
for eg, state, gold in kept:
oracle_actions = self.moves.get_oracle_sequence(eg)
start = 0
while start < len(eg.predicted):
state = state.copy()
n_moves = 0
while state.B(0) < start and not state.is_final():
action = self.moves.c[oracle_actions.pop(0)]
action.do(state.c, action.label)
state.c.push_hist(action.clas)
n_moves += 1
has_gold = self.moves.has_gold(eg, start=start,
end=start+max_length)
if not state.is_final() and has_gold:
states.append(state)
golds.append(gold)
max_moves = max(max_moves, n_moves)
start += min(max_length, len(eg.x)-start)
max_moves = max(max_moves, len(oracle_actions))
return states, golds, max_moves