WIP on rewrite parser

This commit is contained in:
Matthew Honnibal 2021-01-25 23:20:30 +11:00
parent cda3b08dd1
commit b456929bfd
2 changed files with 70 additions and 118 deletions

View File

@ -350,9 +350,9 @@ cdef class Begin:
elif st.B_(1).ent_iob == 3:
# If the next word is B, we can't B now
return False
elif st.B_(1).sent_start == 1:
# Don't allow entities to extend across sentence boundaries
return False
#elif st.B_(1).sent_start == 1:
# # Don't allow entities to extend across sentence boundaries
# return False
# Don't allow entities to start on whitespace
elif Lexeme.get_struct_attr(st.B_(0).lex, IS_SPACE):
return False
@ -418,9 +418,9 @@ cdef class In:
# Otherwise, force acceptance, even if we're across a sentence
# boundary or the token is whitespace.
return True
elif st.B(1) != -1 and st.B_(1).sent_start == 1:
# Don't allow entities to extend across sentence boundaries
return False
#elif st.B(1) != -1 and st.B_(1).sent_start == 1:
# # Don't allow entities to extend across sentence boundaries
# return False
else:
return True

View File

@ -10,7 +10,7 @@ import random
from typing import Optional
import srsly
from thinc.api import set_dropout_rate, CupyOps
from thinc.api import set_dropout_rate, CupyOps, get_array_module
from thinc.extra.search cimport Beam
import numpy.random
import numpy
@ -338,58 +338,79 @@ cdef class Parser(TrainablePipe):
losses=losses,
beam_density=self.cfg["beam_density"]
)
model, backprop_tok2vec = self.model.begin_update([eg.x for eg in examples])
final_states = self.moves.init_batch([eg.x for eg in examples])
self._predict_states(model, final_states)
histories = [list(state.history) for state in final_states]
#oracle_histories = [self.moves.get_oracle_sequence(eg) for eg in examples]
max_moves = self.cfg["update_with_oracle_cut_size"]
if max_moves >= 1:
# Chop sequences into lengths of this many words, to make the
# batch uniform length.
max_moves = int(random.uniform(max_moves // 2, max_moves * 2))
states, golds, _ = self._init_gold_batch(
docs = [eg.x for eg in examples]
model, backprop_tok2vec = self.model.begin_update(docs)
states = self.moves.init_batch(docs)
self._predict_states(states)
# I've separated the prediction from getting the batch because
# I like the idea of trying to store the histories or maybe compute
# them in another process or something. Just walking the states
# and transitioning isn't expensive anyway.
ids, costs = self._get_ids_and_costs_from_histories(
examples,
histories,
max_length=max_moves
[list(state.history) for state in states]
)
else:
states, golds, _ = self.moves.init_gold_batch(examples)
if not states:
return losses
all_states = list(states)
states_golds = list(zip(states, golds))
n_moves = 0
while states_golds:
states, golds = zip(*states_golds)
scores, backprop = model.begin_update(states)
d_scores = self.get_batch_loss(states, golds, scores, losses)
# Note that the gradient isn't normalized by the batch size
# here, because our "samples" are really the states...But we
# can't normalize by the number of states either, as then we'd
# be getting smaller gradients for states in long sequences.
backprop(d_scores)
# Follow the predicted action
self.transition_states(states, scores)
states_golds = [(s, g) for (s, g) in zip(states, golds) if not s.is_final()]
if max_moves >= 1 and n_moves >= max_moves:
break
n_moves += 1
backprop_tok2vec(golds)
scores, backprop_states = model.begin_update(ids)
d_scores = self.get_loss(scores, costs)
d_tokvecs = backprop_states(d_scores)
backprop_tok2vec(d_tokvecs)
if sgd not in (None, False):
self.finish_update(sgd)
self.set_annotations([eg.x for eg in examples], final_states)
self.set_annotations(docs, states)
losses[self.name] += (d_scores**2).sum()
# Ugh, this is annoying. If we're working on GPU, we want to free the
# memory ASAP. It seems that Python doesn't necessarily get around to
# removing these in time if we don't explicitly delete? It's confusing.
del backprop
del backprop_states
del backprop_tok2vec
model.clear_memory()
del model
return losses
def _get_ids_and_costs_from_histories(self, examples, histories):
cdef StateClass state
cdef int clas
cdef int nF = self.model.state2vec.nF
cdef int nO = self.moves.n_moves
cdef int nS = sum([len(history) for history in histories])
# ids and costs have one row per state in the whole batch.
cdef np.ndarray ids = numpy.zeros((nS, nF), dtype="i")
cdef np.ndarray costs = numpy.zeros((nS, nO), dtype="f")
cdef Pool mem = Pool()
is_valid = <int*>mem.alloc(nO, sizeof(int))
c_ids = <int*>ids.data
c_costs = <float*>costs.data
states = self.moves.init_states([eg.x for eg in examples])
cdef int i = 0
for eg, state, history in zip(examples, states, histories):
gold = self.moves.init_gold(state, eg)
for clas in history:
# Set a row into the C data of the arrays (which we return)
state.c.set_context_tokens(&c_ids[i*nF], nF)
self.moves.set_costs(is_valid, &c_costs[i*nO], state.c, gold)
action = self.moves.c[clas]
action.do(state.c, action.label)
state.c.history.push_back(clas)
i += 1
# If the model is on GPU, copy the costs to device.
costs = self.model.ops.asarray(costs)
return ids, costs
def get_loss(self, scores, costs):
xp = get_array_module(scores)
best_costs = costs.min(axis=1, keepdims=True)
is_gold = costs <= costs.min(axis=1, keepdims=True)
gscores = scores[is_gold]
max_ = scores.max(axis=1)
gmax = gscores.max(axis=1, keepdims=True)
exp_scores = xp.exp(scores - max_)
exp_gscores = xp.exp(gscores - gmax)
Z = exp_scores.sum(axis=1, keepdims=True)
gZ = exp_gscores.sum(axis=1, keepdims=True)
d_scores = exp_scores / Z
d_scores[is_gold] -= exp_gscores / gZ
return d_scores
def rehearse(self, examples, sgd=None, losses=None, **cfg):
"""Perform a "rehearsal" update, to prevent catastrophic forgetting."""
if losses is None:
@ -460,36 +481,6 @@ cdef class Parser(TrainablePipe):
if sgd is not None:
self.finish_update(sgd)
def get_batch_loss(self, states, golds, float[:, ::1] scores, losses):
cdef StateClass state
cdef Pool mem = Pool()
cdef int i
# n_moves should not be zero at this point, but make sure to avoid zero-length mem alloc
assert self.moves.n_moves > 0, Errors.E924.format(name=self.name)
is_valid = <int*>mem.alloc(self.moves.n_moves, sizeof(int))
costs = <float*>mem.alloc(self.moves.n_moves, sizeof(float))
cdef np.ndarray d_scores = numpy.zeros((len(states), self.moves.n_moves),
dtype='f', order='C')
c_d_scores = <float*>d_scores.data
unseen_classes = self.model.attrs["unseen_classes"]
for i, (state, gold) in enumerate(zip(states, golds)):
memset(is_valid, 0, self.moves.n_moves * sizeof(int))
memset(costs, 0, self.moves.n_moves * sizeof(float))
self.moves.set_costs(is_valid, costs, state.c, gold)
for j in range(self.moves.n_moves):
if costs[j] <= 0.0 and j in unseen_classes:
unseen_classes.remove(j)
cpu_log_loss(c_d_scores,
costs, is_valid, &scores[i, 0], d_scores.shape[1])
c_d_scores += d_scores.shape[1]
# Note that we don't normalize this. See comment in update() for why.
if losses is not None:
losses.setdefault(self.name, 0.)
losses[self.name] += (d_scores**2).sum()
return d_scores
def set_output(self, nO):
self.model.attrs["resize_output"](self.model, nO)
@ -586,42 +577,3 @@ cdef class Parser(TrainablePipe):
except AttributeError:
raise ValueError(Errors.E149) from None
return self
def _init_gold_batch(self, examples, oracle_histories, max_length):
"""Make a square batch, of length equal to the shortest transition
sequence or a cap. A long
doc will get multiple states. Let's say we have a doc of length 2*N,
where N is the shortest doc. We'll make two states, one representing
long_doc[:N], and another representing long_doc[N:]."""
cdef:
StateClass start_state
StateClass state
Transition action
all_states = self.moves.init_batch([eg.predicted for eg in examples])
assert len(all_states) == len(examples) == len(oracle_histories)
states = []
golds = []
for state, eg, history in zip(all_states, examples, oracle_histories):
if not history:
continue
if not self.moves.has_gold(eg):
continue
gold = self.moves.init_gold(state, eg)
if len(history) < max_length:
states.append(state)
golds.append(gold)
continue
for i in range(0, len(history), max_length):
if state.is_final():
break
start_state = state.copy()
for clas in history[i:i+max_length]:
action = self.moves.c[clas]
action.do(state.c, action.label)
state.c.history.push_back(clas)
if state.is_final():
break
if self.moves.has_gold(eg, start_state.B(0), state.B(0)):
states.append(start_state)
golds.append(gold)
return states, golds, max_length