Fix handling of added labels. Resolves #3189

This commit is contained in:
Matthew Honnibal 2019-02-24 16:41:41 +01:00
parent 4dc57d9e15
commit 0367f864fe
3 changed files with 54 additions and 28 deletions

View File

@ -19,6 +19,7 @@ cdef struct WeightsC:
const float* feat_bias const float* feat_bias
const float* hidden_bias const float* hidden_bias
const float* hidden_weights const float* hidden_weights
const float* seen_classes
cdef struct ActivationsC: cdef struct ActivationsC:

View File

@ -44,8 +44,10 @@ cdef WeightsC get_c_weights(model) except *:
output.feat_bias = <const float*>state2vec.bias.data output.feat_bias = <const float*>state2vec.bias.data
cdef np.ndarray vec2scores_W = model.vec2scores.W cdef np.ndarray vec2scores_W = model.vec2scores.W
cdef np.ndarray vec2scores_b = model.vec2scores.b cdef np.ndarray vec2scores_b = model.vec2scores.b
cdef np.ndarray class_mask = model._class_mask
output.hidden_weights = <const float*>vec2scores_W.data output.hidden_weights = <const float*>vec2scores_W.data
output.hidden_bias = <const float*>vec2scores_b.data output.hidden_bias = <const float*>vec2scores_b.data
output.seen_classes = <const float*>class_mask.data
return output return output
@ -115,6 +117,16 @@ cdef void predict_states(ActivationsC* A, StateC** states,
for i in range(n.states): for i in range(n.states):
VecVec.add_i(&A.scores[i*n.classes], VecVec.add_i(&A.scores[i*n.classes],
W.hidden_bias, 1., n.classes) W.hidden_bias, 1., n.classes)
# Set unseen classes to minimum value
i = 0
min_ = A.scores[0]
for i in range(1, n.states * n.classes):
if A.scores[i] < min_:
min_ = A.scores[i]
for i in range(n.states):
for j in range(n.classes):
if not W.seen_classes[j]:
A.scores[i*n.classes+j] = min_
cdef void sum_state_features(float* output, cdef void sum_state_features(float* output,
@ -189,12 +201,17 @@ cdef int arg_max_if_valid(const weight_t* scores, const int* is_valid, int n) no
class ParserModel(Model): class ParserModel(Model):
def __init__(self, tok2vec, lower_model, upper_model): def __init__(self, tok2vec, lower_model, upper_model, unseen_classes=None):
Model.__init__(self) Model.__init__(self)
self._layers = [tok2vec, lower_model, upper_model] self._layers = [tok2vec, lower_model, upper_model]
self.unseen_classes = set()
if unseen_classes:
for class_ in unseen_classes:
self.unseen_classes.add(class_)
def begin_update(self, docs, drop=0.): def begin_update(self, docs, drop=0.):
step_model = ParserStepModel(docs, self._layers, drop=drop) step_model = ParserStepModel(docs, self._layers, drop=drop,
unseen_classes=self.unseen_classes)
def finish_parser_update(golds, sgd=None): def finish_parser_update(golds, sgd=None):
step_model.make_updates(sgd) step_model.make_updates(sgd)
return None return None
@ -207,9 +224,8 @@ class ParserModel(Model):
with Model.use_device('cpu'): with Model.use_device('cpu'):
larger = Affine(new_output, smaller.nI) larger = Affine(new_output, smaller.nI)
# Set nan as value for unseen classes, to prevent prediction. larger.W.fill(0.0)
larger.W.fill(self.ops.xp.nan) larger.b.fill(0.0)
larger.b.fill(self.ops.xp.nan)
# It seems very unhappy if I pass these as smaller.W? # It seems very unhappy if I pass these as smaller.W?
# Seems to segfault. Maybe it's a descriptor protocol thing? # Seems to segfault. Maybe it's a descriptor protocol thing?
smaller_W = smaller.W smaller_W = smaller.W
@ -221,6 +237,8 @@ class ParserModel(Model):
larger_W[:smaller.nO] = smaller_W larger_W[:smaller.nO] = smaller_W
larger_b[:smaller.nO] = smaller_b larger_b[:smaller.nO] = smaller_b
self._layers[-1] = larger self._layers[-1] = larger
for i in range(smaller.nO, new_output):
self.unseen_classes.add(i)
def begin_training(self, X, y=None): def begin_training(self, X, y=None):
self.lower.begin_training(X, y=y) self.lower.begin_training(X, y=y)
@ -239,18 +257,32 @@ class ParserModel(Model):
class ParserStepModel(Model): class ParserStepModel(Model):
def __init__(self, docs, layers, drop=0.): def __init__(self, docs, layers, unseen_classes=None, drop=0.):
self.tokvecs, self.bp_tokvecs = layers[0].begin_update(docs, drop=drop) self.tokvecs, self.bp_tokvecs = layers[0].begin_update(docs, drop=drop)
self.state2vec = precompute_hiddens(len(docs), self.tokvecs, layers[1], self.state2vec = precompute_hiddens(len(docs), self.tokvecs, layers[1],
drop=drop) drop=drop)
self.vec2scores = layers[-1] self.vec2scores = layers[-1]
self.cuda_stream = util.get_cuda_stream() self.cuda_stream = util.get_cuda_stream()
self.backprops = [] self.backprops = []
self._class_mask = numpy.zeros((self.vec2scores.nO,), dtype='f')
self._class_mask.fill(1)
if unseen_classes is not None:
for class_ in unseen_classes:
self._class_mask[class_] = 0.
@property @property
def nO(self): def nO(self):
return self.state2vec.nO return self.state2vec.nO
def class_is_unseen(self, class_):
    """Report whether output class ``class_`` is currently marked unseen.

    ``self._class_mask`` stores 1 for seen classes and 0 for unseen ones
    (that is what ``mark_class_seen`` / ``mark_class_unseen`` write), so a
    class is unseen exactly when its mask entry is zero.

    class_: index (or numpy-style index expression) into the class mask.
    RETURNS: truthy if the class is unseen, falsy otherwise.
    """
    # The raw mask entry is truthy for *seen* classes, which is the
    # opposite of what this method's name promises; compare against zero
    # so the answer matches the name.
    return self._class_mask[class_] == 0
def mark_class_unseen(self, class_):
    """Flag output class ``class_`` as unseen by zeroing its mask entry."""
    mask = self._class_mask
    mask[class_] = 0
def mark_class_seen(self, class_):
    """Flag output class ``class_`` as seen by setting its mask entry to one."""
    mask = self._class_mask
    mask[class_] = 1
def begin_update(self, states, drop=0.): def begin_update(self, states, drop=0.):
token_ids = self.get_token_ids(states) token_ids = self.get_token_ids(states)
vector, get_d_tokvecs = self.state2vec.begin_update(token_ids, drop=0.0) vector, get_d_tokvecs = self.state2vec.begin_update(token_ids, drop=0.0)
@ -258,24 +290,12 @@ class ParserStepModel(Model):
if mask is not None: if mask is not None:
vector *= mask vector *= mask
scores, get_d_vector = self.vec2scores.begin_update(vector, drop=drop) scores, get_d_vector = self.vec2scores.begin_update(vector, drop=drop)
# We can have nans from unseen classes. # If the class is unseen, make sure its score is minimum
# For backprop purposes, we want to treat unseen classes as having the scores[:, self._class_mask == 0] = numpy.nanmin(scores)
# lowest score.
# numpy's nan_to_num function doesn't take a value, and nan is replaced
# by 0...-inf is replaced by minimum, so we go via that. Ugly to the max.
# Note that scores is always a numpy array! Should fix #3112
scores[numpy.isnan(scores)] = -numpy.inf
numpy.nan_to_num(scores, copy=False)
def backprop_parser_step(d_scores, sgd=None): def backprop_parser_step(d_scores, sgd=None):
# If we have a non-zero gradient for a previously unseen class, # Zero vectors for unseen classes
# replace the weight with 0. d_scores *= self._class_mask
new_classes = self.vec2scores.ops.xp.logical_and(
self.vec2scores.ops.xp.isnan(self.vec2scores.b),
d_scores.any(axis=0)
)
self.vec2scores.b[new_classes] = 0.
self.vec2scores.W[new_classes] = 0.
d_vector = get_d_vector(d_scores, sgd=sgd) d_vector = get_d_vector(d_scores, sgd=sgd)
if mask is not None: if mask is not None:
d_vector *= mask d_vector *= mask

View File

@ -163,6 +163,8 @@ cdef class Parser:
added = self.moves.add_action(action, label) added = self.moves.add_action(action, label)
if added: if added:
resized = True resized = True
if resized:
self.cfg["nr_class"] = self.moves.n_moves
if self.model not in (True, False, None) and resized: if self.model not in (True, False, None) and resized:
self.model.resize_output(self.moves.n_moves) self.model.resize_output(self.moves.n_moves)
@ -435,22 +437,22 @@ cdef class Parser:
if self._rehearsal_model is None: if self._rehearsal_model is None:
return None return None
losses.setdefault(self.name, 0.) losses.setdefault(self.name, 0.)
states = self.moves.init_batch(docs) states = self.moves.init_batch(docs)
# This is pretty dirty, but the NER can resize itself in init_batch, # This is pretty dirty, but the NER can resize itself in init_batch,
# if labels are missing. We therefore have to check whether we need to # if labels are missing. We therefore have to check whether we need to
# expand our model output. # expand our model output.
self.model.resize_output(self.moves.n_moves) self.model.resize_output(self.moves.n_moves)
self._rehearsal_model.resize_output(self.moves.n_moves)
# Prepare the stepwise model, and get the callback for finishing the batch # Prepare the stepwise model, and get the callback for finishing the batch
tutor = self._rehearsal_model(docs) tutor, _ = self._rehearsal_model.begin_update(docs, drop=0.0)
model, finish_update = self.model.begin_update(docs, drop=0.0) model, finish_update = self.model.begin_update(docs, drop=0.0)
n_scores = 0. n_scores = 0.
loss = 0. loss = 0.
non_zeroed_classes = self._rehearsal_model.upper.W.any(axis=1)
while states: while states:
targets, _ = tutor.begin_update(states) targets, _ = tutor.begin_update(states, drop=0.)
guesses, backprop = model.begin_update(states) guesses, backprop = model.begin_update(states, drop=0.)
d_scores = (targets - guesses) / targets.shape[0] d_scores = (guesses - targets) / targets.shape[0]
d_scores *= non_zeroed_classes
# If all weights for an output are 0 in the original model, don't # If all weights for an output are 0 in the original model, don't
# supervise that output. This allows us to add classes. # supervise that output. This allows us to add classes.
loss += (d_scores**2).sum() loss += (d_scores**2).sum()
@ -543,6 +545,9 @@ cdef class Parser:
memset(is_valid, 0, self.moves.n_moves * sizeof(int)) memset(is_valid, 0, self.moves.n_moves * sizeof(int))
memset(costs, 0, self.moves.n_moves * sizeof(float)) memset(costs, 0, self.moves.n_moves * sizeof(float))
self.moves.set_costs(is_valid, costs, state, gold) self.moves.set_costs(is_valid, costs, state, gold)
for j in range(self.moves.n_moves):
if costs[j] <= 0.0 and j in self.model.unseen_classes:
self.model.unseen_classes.remove(j)
cpu_log_loss(c_d_scores, cpu_log_loss(c_d_scores,
costs, is_valid, &scores[i, 0], d_scores.shape[1]) costs, is_valid, &scores[i, 0], d_scores.shape[1])
c_d_scores += d_scores.shape[1] c_d_scores += d_scores.shape[1]