mirror of
https://github.com/explosion/spaCy.git
synced 2025-02-10 16:40:34 +03:00
Restore the 'cutting' in parser training
This commit is contained in:
parent
6bda23ad26
commit
403b362a5d
|
@ -576,15 +576,20 @@ cdef class ArcEager(TransitionSystem):
|
|||
def is_gold_parse(self, StateClass state, gold):
|
||||
raise NotImplementedError
|
||||
|
||||
def init_gold(self, StateClass state, Example example):
|
||||
gold = ArcEagerGold(self, state, example)
|
||||
self._replace_unseen_labels(gold)
|
||||
return gold
|
||||
|
||||
def init_gold_batch(self, examples):
|
||||
states = self.init_batch([eg.predicted for eg in examples])
|
||||
keeps = [i for i, (eg, s) in enumerate(zip(examples, states))
|
||||
if self.has_gold(eg) and not s.is_final()]
|
||||
golds = [ArcEagerGold(self, states[i], examples[i]) for i in keeps]
|
||||
states = [states[i] for i in keeps]
|
||||
for gold in golds:
|
||||
self._replace_unseen_labels(gold)
|
||||
n_steps = sum([len(s.queue) * 4 for s in states])
|
||||
all_states = self.init_batch([eg.predicted for eg in examples])
|
||||
golds = []
|
||||
states = []
|
||||
for state, eg in zip(all_states, examples):
|
||||
if self.has_gold(eg) and not state.is_final():
|
||||
golds.append(self.init_gold(state, eg))
|
||||
states.append(state)
|
||||
n_steps = sum([len(s.queue) for s in states])
|
||||
return states, golds, n_steps
|
||||
|
||||
def _replace_unseen_labels(self, ArcEagerGold gold):
|
||||
|
@ -684,8 +689,12 @@ cdef class ArcEager(TransitionSystem):
|
|||
doc.is_parsed = True
|
||||
set_children_from_heads(doc.c, doc.length)
|
||||
|
||||
def has_gold(self, Example eg):
|
||||
return eg.y.is_parsed
|
||||
def has_gold(self, Example eg, start=0, end=None):
|
||||
for word in eg.y[start:end]:
|
||||
if word.dep != 0:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
cdef int set_valid(self, int* output, const StateC* st) nogil:
|
||||
cdef bint[N_MOVES] is_valid
|
||||
|
|
|
@ -130,11 +130,13 @@ cdef class BiluoPushDown(TransitionSystem):
|
|||
return MOVE_NAMES[move] + '-' + self.strings[label]
|
||||
|
||||
def init_gold_batch(self, examples):
|
||||
states = self.init_batch([eg.predicted for eg in examples])
|
||||
keeps = [i for i, (s, eg) in enumerate(zip(states, examples))
|
||||
if not s.is_final() and self.has_gold(eg)]
|
||||
golds = [BiluoGold(self, states[i], examples[i]) for i in keeps]
|
||||
states = [states[i] for i in keeps]
|
||||
all_states = self.init_batch([eg.predicted for eg in examples])
|
||||
golds = []
|
||||
states = []
|
||||
for state, eg in zip(all_states, examples):
|
||||
if self.has_gold(eg) and not state.is_final():
|
||||
golds.append(self.init_gold(state, eg))
|
||||
states.append(state)
|
||||
n_steps = sum([len(s.queue) for s in states])
|
||||
return states, golds, n_steps
|
||||
|
||||
|
@ -237,8 +239,15 @@ cdef class BiluoPushDown(TransitionSystem):
|
|||
self.add_action(UNIT, st._sent[i].ent_type)
|
||||
self.add_action(LAST, st._sent[i].ent_type)
|
||||
|
||||
def has_gold(self, Example eg):
|
||||
return eg.y.is_nered
|
||||
def init_gold(self, StateClass state, Example example):
|
||||
return BiluoGold(self, state, example)
|
||||
|
||||
def has_gold(self, Example eg, start=0, end=None):
|
||||
for word in eg.y[start:end]:
|
||||
if word.ent_iob != 0:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def get_cost(self, StateClass stcls, gold, int i):
|
||||
if not isinstance(gold, BiluoGold):
|
||||
|
|
|
@ -272,7 +272,7 @@ cdef class Parser:
|
|||
# Prepare the stepwise model, and get the callback for finishing the batch
|
||||
model, backprop_tok2vec = self.model.begin_update(
|
||||
[eg.predicted for eg in examples])
|
||||
states, golds, max_steps = self.moves.init_gold_batch(examples)
|
||||
states, golds, max_steps = self._init_gold_batch(examples)
|
||||
all_states = list(states)
|
||||
states_golds = zip(states, golds)
|
||||
for _ in range(max_steps):
|
||||
|
@ -490,3 +490,42 @@ cdef class Parser:
|
|||
except AttributeError:
|
||||
raise ValueError(Errors.E149)
|
||||
return self
|
||||
|
||||
def _init_gold_batch(self, examples, min_length=5, max_length=500):
|
||||
"""Make a square batch, of length equal to the shortest doc. A long
|
||||
doc will get multiple states. Let's say we have a doc of length 2*N,
|
||||
where N is the shortest doc. We'll make two states, one representing
|
||||
long_doc[:N], and another representing long_doc[N:]."""
|
||||
cdef:
|
||||
StateClass state
|
||||
Transition action
|
||||
all_states = self.moves.init_batch([eg.predicted for eg in examples])
|
||||
kept = []
|
||||
for state, eg in zip(all_states, examples):
|
||||
if self.moves.has_gold(eg) and not state.is_final():
|
||||
gold = self.moves.init_gold(state, eg)
|
||||
kept.append((eg, state, gold))
|
||||
max_length = max(min_length, min(max_length, min([len(eg.x) for eg in examples])))
|
||||
max_moves = 0
|
||||
states = []
|
||||
golds = []
|
||||
for eg, state, gold in kept:
|
||||
oracle_actions = self.moves.get_oracle_sequence(eg)
|
||||
start = 0
|
||||
while start < len(eg.predicted):
|
||||
state = state.copy()
|
||||
n_moves = 0
|
||||
while state.B(0) < start and not state.is_final():
|
||||
action = self.moves.c[oracle_actions.pop(0)]
|
||||
action.do(state.c, action.label)
|
||||
state.c.push_hist(action.clas)
|
||||
n_moves += 1
|
||||
has_gold = self.moves.has_gold(eg, start=start,
|
||||
end=start+max_length)
|
||||
if not state.is_final() and has_gold:
|
||||
states.append(state)
|
||||
golds.append(gold)
|
||||
max_moves = max(max_moves, n_moves)
|
||||
start += min(max_length, len(eg.x)-start)
|
||||
max_moves = max(max_moves, len(oracle_actions))
|
||||
return states, golds, max_moves
|
||||
|
|
Loading…
Reference in New Issue
Block a user