mirror of
https://github.com/explosion/spaCy.git
synced 2025-02-11 09:00:36 +03:00
Restore the 'cutting' in parser training
This commit is contained in:
parent
6bda23ad26
commit
403b362a5d
|
@ -576,15 +576,20 @@ cdef class ArcEager(TransitionSystem):
|
||||||
def is_gold_parse(self, StateClass state, gold):
    """Placeholder for checking `state` against `gold`.

    Not implemented here: calling this always raises NotImplementedError.
    """
    raise NotImplementedError()
|
||||||
|
|
||||||
|
def init_gold(self, StateClass state, Example example):
    """Build the gold-standard object for one (state, example) pair.

    Constructs an ArcEagerGold from this transition system, the parse
    state and the example, passes it through `_replace_unseen_labels`,
    and returns it.
    """
    arc_eager_gold = ArcEagerGold(self, state, example)
    self._replace_unseen_labels(arc_eager_gold)
    return arc_eager_gold
|
||||||
|
|
||||||
def init_gold_batch(self, examples):
    """Create the trainable states and gold objects for a batch.

    A state is initialised for every example's predicted doc. Only the
    examples that have gold annotation (`has_gold`) and whose state is
    not already final are kept.

    Returns a `(states, golds, n_steps)` tuple, where `states` and
    `golds` are parallel lists and `n_steps` is the summed queue length
    of the kept states.
    """
    initial_states = self.init_batch([eg.predicted for eg in examples])
    # Pair each surviving state with its gold object. The filter runs
    # before init_gold, so gold objects are only built for kept states.
    pairs = [
        (state, self.init_gold(state, eg))
        for state, eg in zip(initial_states, examples)
        if self.has_gold(eg) and not state.is_final()
    ]
    states = [state for state, _ in pairs]
    golds = [gold for _, gold in pairs]
    n_steps = sum([len(state.queue) for state in states])
    return states, golds, n_steps
|
||||||
|
|
||||||
def _replace_unseen_labels(self, ArcEagerGold gold):
|
def _replace_unseen_labels(self, ArcEagerGold gold):
|
||||||
|
@ -684,8 +689,12 @@ cdef class ArcEager(TransitionSystem):
|
||||||
doc.is_parsed = True
|
doc.is_parsed = True
|
||||||
set_children_from_heads(doc.c, doc.length)
|
set_children_from_heads(doc.c, doc.length)
|
||||||
|
|
||||||
def has_gold(self, Example eg, start=0, end=None):
    """Return True if any reference token in `eg.y[start:end]` has a
    dependency label set (`word.dep != 0`), otherwise False.

    eg: the example whose reference side (`eg.y`) is inspected.
    start, end: optional slice bounds into `eg.y`; the defaults cover
        the whole reference doc.
    """
    # any() replaces a for/else that had no `break`: the `else` there was
    # just the fall-through return, which reads as if it were conditional.
    return any(word.dep != 0 for word in eg.y[start:end])
|
||||||
|
|
||||||
cdef int set_valid(self, int* output, const StateC* st) nogil:
|
cdef int set_valid(self, int* output, const StateC* st) nogil:
|
||||||
cdef bint[N_MOVES] is_valid
|
cdef bint[N_MOVES] is_valid
|
||||||
|
|
|
@ -130,11 +130,13 @@ cdef class BiluoPushDown(TransitionSystem):
|
||||||
return MOVE_NAMES[move] + '-' + self.strings[label]
|
return MOVE_NAMES[move] + '-' + self.strings[label]
|
||||||
|
|
||||||
def init_gold_batch(self, examples):
    """Create the trainable states and gold objects for a batch.

    Every example's predicted doc gets an initial state; an example is
    kept only if it has gold annotation (`has_gold`) and its state is
    not already final.

    Returns `(states, golds, n_steps)`: two parallel lists plus the
    summed queue length of the kept states.
    """
    candidates = zip(self.init_batch([eg.predicted for eg in examples]),
                     examples)
    states = []
    golds = []
    for state, eg in candidates:
        # Guard clause: skip examples without gold or with nothing to do.
        if not self.has_gold(eg) or state.is_final():
            continue
        golds.append(self.init_gold(state, eg))
        states.append(state)
    n_steps = sum([len(state.queue) for state in states])
    return states, golds, n_steps
|
||||||
|
|
||||||
|
@ -237,8 +239,15 @@ cdef class BiluoPushDown(TransitionSystem):
|
||||||
self.add_action(UNIT, st._sent[i].ent_type)
|
self.add_action(UNIT, st._sent[i].ent_type)
|
||||||
self.add_action(LAST, st._sent[i].ent_type)
|
self.add_action(LAST, st._sent[i].ent_type)
|
||||||
|
|
||||||
def init_gold(self, StateClass state, Example example):
    """Build and return the BiluoGold object for one (state, example)
    pair, constructed from this transition system, the state and the
    example.
    """
    biluo_gold = BiluoGold(self, state, example)
    return biluo_gold
|
||||||
|
|
||||||
|
def has_gold(self, Example eg, start=0, end=None):
    """Return True if any reference token in `eg.y[start:end]` carries
    entity annotation (`word.ent_iob != 0`), otherwise False.

    eg: the example whose reference side (`eg.y`) is inspected.
    start, end: optional slice bounds into `eg.y`; the defaults cover
        the whole reference doc.
    """
    # any() replaces a for/else that had no `break`: the `else` there was
    # just the fall-through return, which reads as if it were conditional.
    return any(word.ent_iob != 0 for word in eg.y[start:end])
|
||||||
|
|
||||||
def get_cost(self, StateClass stcls, gold, int i):
|
def get_cost(self, StateClass stcls, gold, int i):
|
||||||
if not isinstance(gold, BiluoGold):
|
if not isinstance(gold, BiluoGold):
|
||||||
|
|
|
@ -272,7 +272,7 @@ cdef class Parser:
|
||||||
# Prepare the stepwise model, and get the callback for finishing the batch
|
# Prepare the stepwise model, and get the callback for finishing the batch
|
||||||
model, backprop_tok2vec = self.model.begin_update(
|
model, backprop_tok2vec = self.model.begin_update(
|
||||||
[eg.predicted for eg in examples])
|
[eg.predicted for eg in examples])
|
||||||
states, golds, max_steps = self.moves.init_gold_batch(examples)
|
states, golds, max_steps = self._init_gold_batch(examples)
|
||||||
all_states = list(states)
|
all_states = list(states)
|
||||||
states_golds = zip(states, golds)
|
states_golds = zip(states, golds)
|
||||||
for _ in range(max_steps):
|
for _ in range(max_steps):
|
||||||
|
@ -490,3 +490,42 @@ cdef class Parser:
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
raise ValueError(Errors.E149)
|
raise ValueError(Errors.E149)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
def _init_gold_batch(self, examples, min_length=5, max_length=500):
    """Cut long docs into several shorter training states.

    The cut length is the length of the shortest doc in the batch,
    clamped to the range [min_length, max_length]. A doc longer than the
    cut length yields multiple states: for a doc of length 2*N with cut
    length N, one state represents long_doc[:N] and another represents
    long_doc[N:]. Each later state is reached by replaying the oracle
    action sequence until the front of the buffer reaches the cut point.

    examples: the batch of examples to prepare.
    min_length: lower bound on the cut length.
    max_length: upper bound on the cut length.

    Returns `(states, golds, max_moves)`. `states` and `golds` are
    parallel lists; segments of one doc share that doc's gold object.
    `max_moves` is the largest of: the number of oracle moves replayed to
    reach any kept state, and the number of oracle moves left unreplayed
    for any doc. An empty batch, or one with nothing to train on,
    returns `([], [], 0)`.
    """
    cdef:
        StateClass state
        Transition action
    all_states = self.moves.init_batch([eg.predicted for eg in examples])
    kept = []
    for state, eg in zip(all_states, examples):
        if self.moves.has_gold(eg) and not state.is_final():
            gold = self.moves.init_gold(state, eg)
            kept.append((eg, state, gold))
    if not kept:
        # Nothing to train on. Returning here also avoids min() raising
        # ValueError on an empty sequence when `examples` is empty.
        return [], [], 0
    shortest = min([len(eg.x) for eg in examples])
    max_length = max(min_length, min(max_length, shortest))
    max_moves = 0
    states = []
    golds = []
    for eg, state, gold in kept:
        oracle_actions = self.moves.get_oracle_sequence(eg)
        # Cursor into oracle_actions. Indexing avoids the O(n) cost of
        # popping from the front of the list for every replayed move.
        n_applied = 0
        start = 0
        # NOTE(review): the loop bound uses eg.predicted while the step
        # below uses eg.x; presumably the same doc. Confirm.
        while start < len(eg.predicted):
            # Work on a copy so each segment's state is independent.
            state = state.copy()
            n_moves = 0
            # Replay oracle moves until the buffer front reaches `start`.
            while state.B(0) < start and not state.is_final():
                action = self.moves.c[oracle_actions[n_applied]]
                n_applied += 1
                action.do(state.c, action.label)
                state.c.push_hist(action.clas)
                n_moves += 1
            has_gold = self.moves.has_gold(eg, start=start,
                                           end=start+max_length)
            if not state.is_final() and has_gold:
                states.append(state)
                golds.append(gold)
                max_moves = max(max_moves, n_moves)
            start += min(max_length, len(eg.x)-start)
        # Moves not yet replayed still bound the steps the caller needs.
        max_moves = max(max_moves, len(oracle_actions) - n_applied)
    return states, golds, max_moves
|
||||||
|
|
Loading…
Reference in New Issue
Block a user