mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-26 09:56:28 +03:00
Fix infinite loop bug
This commit is contained in:
parent
94e86ae00a
commit
50ddc9fc45
|
@ -114,7 +114,7 @@ cdef class Parser:
|
||||||
def __reduce__(self):
|
def __reduce__(self):
|
||||||
return (Parser, (self.vocab, self.moves, self.model), None, None)
|
return (Parser, (self.vocab, self.moves, self.model), None, None)
|
||||||
|
|
||||||
def build_model(self, width=64, nr_vector=1000, nF=1, nB=1, nS=1, nL=1, nR=1, **_):
|
def build_model(self, width=128, nr_vector=1000, nF=1, nB=1, nS=1, nL=1, nR=1, **_):
|
||||||
nr_context_tokens = StateClass.nr_context_tokens(nF, nB, nS, nL, nR)
|
nr_context_tokens = StateClass.nr_context_tokens(nF, nB, nS, nL, nR)
|
||||||
state2vec = build_state2vec(nr_context_tokens, width, nr_vector)
|
state2vec = build_state2vec(nr_context_tokens, width, nr_vector)
|
||||||
#state2vec = build_debug_state2vec(width, nr_vector)
|
#state2vec = build_debug_state2vec(width, nr_vector)
|
||||||
|
@ -175,17 +175,13 @@ cdef class Parser:
|
||||||
tokvecs = [d.tensor for d in docs]
|
tokvecs = [d.tensor for d in docs]
|
||||||
all_states = list(states)
|
all_states = list(states)
|
||||||
todo = zip(states, tokvecs)
|
todo = zip(states, tokvecs)
|
||||||
i = 0
|
|
||||||
while todo:
|
while todo:
|
||||||
|
todo = filter(lambda sp: not sp[0].py_is_final(), todo)
|
||||||
|
if not todo:
|
||||||
|
break
|
||||||
states, tokvecs = zip(*todo)
|
states, tokvecs = zip(*todo)
|
||||||
scores, _ = self._begin_update(states, tokvecs)
|
scores, _ = self._begin_update(states, tokvecs)
|
||||||
for state, guess in zip(states, scores.argmax(axis=1)):
|
self._transition_batch(states, docs, scores)
|
||||||
action = self.moves.c[guess]
|
|
||||||
action.do(state.c, action.label)
|
|
||||||
todo = filter(lambda sp: not sp[0].py_is_final(), todo)
|
|
||||||
i += 1
|
|
||||||
if i >= 10000:
|
|
||||||
break
|
|
||||||
for state, doc in zip(all_states, docs):
|
for state, doc in zip(all_states, docs):
|
||||||
self.moves.finalize_state(state.c)
|
self.moves.finalize_state(state.c)
|
||||||
for i in range(doc.length):
|
for i in range(doc.length):
|
||||||
|
@ -208,7 +204,6 @@ cdef class Parser:
|
||||||
features = self._get_features(states, tokvecs, attr_names)
|
features = self._get_features(states, tokvecs, attr_names)
|
||||||
self.model.begin_training(features)
|
self.model.begin_training(features)
|
||||||
|
|
||||||
|
|
||||||
def update(self, docs, golds, drop=0., sgd=None):
|
def update(self, docs, golds, drop=0., sgd=None):
|
||||||
if isinstance(docs, Doc) and isinstance(golds, GoldParse):
|
if isinstance(docs, Doc) and isinstance(golds, GoldParse):
|
||||||
return self.update([docs], [golds], drop=drop)
|
return self.update([docs], [golds], drop=drop)
|
||||||
|
@ -222,12 +217,17 @@ cdef class Parser:
|
||||||
todo = zip(states, tokvecs, golds, d_tokens)
|
todo = zip(states, tokvecs, golds, d_tokens)
|
||||||
assert len(states) == len(todo)
|
assert len(states) == len(todo)
|
||||||
losses = []
|
losses = []
|
||||||
i = 0
|
|
||||||
while todo:
|
while todo:
|
||||||
|
# Get unfinished states (and their matching gold and token gradients)
|
||||||
|
todo = filter(lambda sp: not sp[0].py_is_final(), todo)
|
||||||
|
if not todo:
|
||||||
|
break
|
||||||
states, tokvecs, golds, d_tokens = zip(*todo)
|
states, tokvecs, golds, d_tokens = zip(*todo)
|
||||||
scores, finish_update = self._begin_update(states, tokvecs)
|
scores, finish_update = self._begin_update(states, tokvecs)
|
||||||
token_ids, batch_token_grads = finish_update(golds, sgd=sgd, losses=losses,
|
token_ids, batch_token_grads = finish_update(golds, sgd=sgd, losses=losses,
|
||||||
force_gold=False)
|
force_gold=False)
|
||||||
|
batch_token_grads *= (token_ids >= 0).reshape((token_ids.shape[0], token_ids.shape[1], 1))
|
||||||
|
token_ids *= token_ids >= 0
|
||||||
if hasattr(self.model.ops.xp, 'scatter_add'):
|
if hasattr(self.model.ops.xp, 'scatter_add'):
|
||||||
for i, tok_ids in enumerate(token_ids):
|
for i, tok_ids in enumerate(token_ids):
|
||||||
self.model.ops.xp.scatter_add(d_tokens[i],
|
self.model.ops.xp.scatter_add(d_tokens[i],
|
||||||
|
@ -236,14 +236,7 @@ cdef class Parser:
|
||||||
for i, tok_ids in enumerate(token_ids):
|
for i, tok_ids in enumerate(token_ids):
|
||||||
self.model.ops.xp.add.at(d_tokens[i],
|
self.model.ops.xp.add.at(d_tokens[i],
|
||||||
tok_ids, batch_token_grads[i])
|
tok_ids, batch_token_grads[i])
|
||||||
|
self._transition_batch(states, docs, scores)
|
||||||
self._transition_batch(states, scores)
|
|
||||||
|
|
||||||
# Get unfinished states (and their matching gold and token gradients)
|
|
||||||
todo = filter(lambda sp: not sp[0].py_is_final(), todo)
|
|
||||||
i += 1
|
|
||||||
if i >= 10000:
|
|
||||||
break
|
|
||||||
return output, sum(losses)
|
return output, sum(losses)
|
||||||
|
|
||||||
def _begin_update(self, states, tokvecs, drop=0.):
|
def _begin_update(self, states, tokvecs, drop=0.):
|
||||||
|
@ -293,39 +286,35 @@ cdef class Parser:
|
||||||
tokvecs = self.model.ops.allocate((len(states), n_tokens, vector_length), dtype='f')
|
tokvecs = self.model.ops.allocate((len(states), n_tokens, vector_length), dtype='f')
|
||||||
for i, state in enumerate(states):
|
for i, state in enumerate(states):
|
||||||
state.set_context_tokens(cpu_tokens[i], nF, nB, nS, nL, nR)
|
state.set_context_tokens(cpu_tokens[i], nF, nB, nS, nL, nR)
|
||||||
#state.set_attributes(features[i], tokens[i], attr_names)
|
|
||||||
gpu_tokens = self.model.ops.xp.array(cpu_tokens)
|
|
||||||
for i in range(len(states)):
|
for i in range(len(states)):
|
||||||
tokvecs[i] = all_tokvecs[i][gpu_tokens[i]]
|
for j, tok_i in enumerate(cpu_tokens[i]):
|
||||||
tokvecs *= (gpu_tokens >= 0).reshape((gpu_tokens.shape[0], gpu_tokens.shape[1], 1))
|
if tok_i >= 0:
|
||||||
return (gpu_tokens, self.model.ops.asarray(features), tokvecs)
|
tokvecs[i, j] = all_tokvecs[i][tok_i]
|
||||||
|
return (cpu_tokens, self.model.ops.asarray(features), tokvecs)
|
||||||
|
|
||||||
def _validate_batch(self, is_valid, states):
|
def _validate_batch(self, int[:, ::1] is_valid, states):
|
||||||
cdef StateClass state
|
cdef StateClass state
|
||||||
cdef int i
|
cdef int i
|
||||||
cdef int[:, :] is_valid_cpu = is_valid.get()
|
|
||||||
for i, state in enumerate(states):
|
for i, state in enumerate(states):
|
||||||
self.moves.set_valid(&is_valid_cpu[i, 0], state.c)
|
self.moves.set_valid(&is_valid[i, 0], state.c)
|
||||||
is_valid.set(numpy.asarray(is_valid_cpu))
|
|
||||||
|
|
||||||
def _cost_batch(self, costs, is_valid,
|
def _cost_batch(self, float[:, ::1] costs, int[:, ::1] is_valid,
|
||||||
states, golds):
|
states, golds):
|
||||||
cdef int i
|
cdef int i
|
||||||
cdef StateClass state
|
cdef StateClass state
|
||||||
cdef GoldParse gold
|
cdef GoldParse gold
|
||||||
cdef int[:, :] is_valid_cpu = is_valid.get()
|
|
||||||
cdef weight_t[:, :] costs_cpu = costs.get()
|
|
||||||
|
|
||||||
for i, (state, gold) in enumerate(zip(states, golds)):
|
for i, (state, gold) in enumerate(zip(states, golds)):
|
||||||
self.moves.set_costs(&is_valid_cpu[i, 0], &costs_cpu[i, 0], state, gold)
|
self.moves.set_costs(&is_valid[i, 0], &costs[i, 0], state, gold)
|
||||||
is_valid.set(numpy.asarray(is_valid_cpu))
|
|
||||||
costs.set(numpy.asarray(costs_cpu))
|
|
||||||
|
|
||||||
def _transition_batch(self, states, scores):
|
def _transition_batch(self, states, docs, scores):
|
||||||
cdef StateClass state
|
cdef StateClass state
|
||||||
cdef int guess
|
cdef int guess
|
||||||
for state, guess in zip(states, scores.argmax(axis=1)):
|
for state, doc, guess in zip(states, docs, scores.argmax(axis=1)):
|
||||||
action = self.moves.c[guess]
|
action = self.moves.c[guess]
|
||||||
|
orths = [t.lex.orth for t in state.c._sent[:state.c.length]]
|
||||||
|
words = [doc.vocab.strings[w] for w in orths]
|
||||||
|
if not action.is_valid(state.c, action.label):
|
||||||
|
ValueError("Invalid action", scores)
|
||||||
action.do(state.c, action.label)
|
action.do(state.c, action.label)
|
||||||
|
|
||||||
def _set_gradient(self, gradients, scores, is_valid, costs):
|
def _set_gradient(self, gradients, scores, is_valid, costs):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user