Mirror of https://github.com/explosion/spaCy.git (synced 2025-01-12 18:26:30 +03:00)

Refactor model for beam parser, to avoid conditionals on model type

parent a57f337d29 · commit 6b912731f8
@@ -22,6 +22,8 @@ from ._parse_features cimport fill_context
 from ._parse_features cimport CONTEXT_SIZE
 from ._parse_features cimport fill_context
 from ._parse_features cimport *
+from .transition_system cimport TransitionSystem
+from ..tokens.doc cimport Doc
 
 
 cdef class ParserPerceptron(AveragedPerceptron):
@@ -51,6 +53,23 @@ cdef class ParserPerceptron(AveragedPerceptron):
         fill_context(eg.atoms, state)
         eg.nr_feat = self.extracter.set_features(eg.features, eg.atoms)
 
+    def _update_from_history(self, TransitionSystem moves, Doc doc, history, weight_t grad):
+        cdef Pool mem = Pool()
+        cdef atom_t[CONTEXT_SIZE] context
+        features = <FeatureC*>mem.alloc(self.nr_feat, sizeof(FeatureC))
+
+        cdef StateClass stcls = StateClass.init(doc.c, doc.length)
+        moves.initialize_state(stcls.c)
+
+        cdef class_t clas
+        self.time += 1
+        for clas in history:
+            fill_context(context, stcls.c)
+            nr_feat = self.extracter.set_features(features, context)
+            for feat in features[:nr_feat]:
+                self.update_weight(feat.key, clas, feat.value * grad)
+            moves.c[clas].do(stcls.c, moves.c[clas].label)
+
 
 cdef class ParserNeuralNet(NeuralNet):
     def __init__(self, shape, **kwargs):
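The perceptron update added above replays one beam candidate's action history from the initial parse state, bumping the weight of every extracted feature for the chosen class before applying the transition. A rough Python-level sketch of that loop, using a plain dict for the weights and hypothetical helper callables (extract_features, apply_action, initial_state) in place of the C-level feature arrays:

# Illustrative sketch only; names are stand-ins, not real spaCy API.
def update_from_history(weights, doc, history, grad,
                        extract_features, apply_action, initial_state):
    state = initial_state(doc)
    for clas in history:
        # Features come from the *current* state, so the update walks the
        # same sequence of states the beam candidate visited.
        for key, value in extract_features(state):
            weights[(key, clas)] = weights.get((key, clas), 0.0) + value * grad
        # Apply the transition so the next iteration sees the next state.
        state = apply_action(state, clas)
    return weights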
@@ -123,6 +142,41 @@ cdef class ParserNeuralNet(NeuralNet):
     cdef void _softmaxC(self, weight_t* out) nogil:
         pass
 
+    def _update_from_history(self, TransitionSystem moves, Doc doc, history, weight_t grad):
+        cdef Example py_eg = Example(nr_class=moves.n_moves, nr_atom=CONTEXT_SIZE,
+                                     nr_feat=self.nr_feat, widths=self.widths)
+        stcls = StateClass.init(doc.c, doc.length)
+        moves.initialize_state(stcls.c)
+        cdef uint64_t[2] key
+        key[0] = hash64(doc.c, sizeof(TokenC) * doc.length, 0)
+        key[1] = 0
+        cdef uint64_t clas
+        for clas in history:
+            self.set_featuresC(py_eg.c, stcls.c)
+            moves.set_valid(py_eg.c.is_valid, stcls.c)
+            # Update with a sparse gradient: everything's 0, except our class.
+            # Remember, this is a component of the global update. It's not our
+            # "job" here to think about the other beam candidates. We just want
+            # to work on this sequence. However, other beam candidates will
+            # have gradients that refer to the same state.
+            # We therefore have a key that indicates the current sequence, so that
+            # the model can merge updates that refer to the same state together,
+            # by summing their gradients.
+            memset(py_eg.c.costs, 0, self.moves.n_moves)
+            py_eg.c.costs[clas] = grad
+            self.updateC(
+                py_eg.c.features, py_eg.c.nr_feat, True, py_eg.c.costs, py_eg.c.is_valid,
+                False, key=key[0])
+            moves.c[clas].do(stcls.c, self.moves.c[clas].label)
+            py_eg.c.reset()
+            # Build a hash of the state sequence.
+            # Position 0 represents the previous sequence, position 1 the new class.
+            # So we want to do:
+            # key.prev = hash((key.prev, key.new))
+            # key.new = clas
+            key[1] = clas
+            key[0] = hash64(key, sizeof(key), 0)
+
 
 cdef inline FeatureC* _add_token(FeatureC* feats,
                                  int slot, const TokenC* token, weight_t value) nogil:
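The comments in the hunk above describe keying each sparse update by the state sequence reached so far, so that updates from beam candidates sharing a prefix of actions can be merged by summing their gradients. A small illustrative sketch of that rolling key, with Python's built-in hash() standing in for hash64 and a plain dict standing in for the model's internal merging:

# Illustrative sketch only; not the real update machinery.
def accumulate_updates(doc_key, histories_and_grads):
    merged = {}  # state-sequence key -> {class: summed gradient}
    for history, grad in histories_and_grads:
        key = doc_key  # identifies the initial state for this document
        for clas in history:
            # Candidates sharing the same prefix of actions reach this point
            # with the same key, so their gradients are summed together.
            bucket = merged.setdefault(key, {})
            bucket[clas] = bucket.get(clas, 0.0) + grad
            # Roll the key forward: the next state is identified by the
            # previous sequence plus the action just taken.
            key = hash((key, clas))
    return merged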
@@ -113,33 +113,12 @@ cdef class BeamParser(Parser):
                 break
         else:
             violn.check_crf(pred, gold)
-            if isinstance(self.model, ParserNeuralNet):
-                min_grad = 0.1 ** (itn+1)
-                for grad, hist in zip(violn.p_probs, violn.p_hist):
-                    assert not math.isnan(grad) and not math.isinf(grad)
-                    if abs(grad) >= min_grad:
-                        self._update_dense(tokens, hist, grad)
-                for grad, hist in zip(violn.g_probs, violn.g_hist):
-                    assert not math.isnan(grad) and not math.isinf(grad)
-                    if abs(grad) >= min_grad:
-                        self._update_dense(tokens, hist, grad)
-            else:
-                self.model.time += 1
-                #min_grad = 0.01 ** (itn+1)
-                #for grad, hist in zip(violn.p_probs, violn.p_hist):
-                #    assert not math.isnan(grad)
-                #    assert not math.isinf(grad)
-                #    if abs(grad) >= min_grad:
-                #        self._update(tokens, hist, -grad)
-                #for grad, hist in zip(violn.g_probs, violn.g_hist):
-                #    assert not math.isnan(grad)
-                #    assert not math.isinf(grad)
-                #    if abs(grad) >= min_grad:
-                #        self._update(tokens, hist, -grad)
-                if violn.p_hist:
-                    self._update(tokens, violn.p_hist[0], -1.0)
-                if violn.g_hist:
-                    self._update(tokens, violn.g_hist[0], 1.0)
+            min_grad = 0.1 ** (itn+1)
+            histories = zip(violn.p_probs, violn.p_hist) + zip(violn.g_probs, violn.g_hist)
+            for grad, hist in histories:
+                assert not math.isnan(grad) and not math.isinf(grad)
+                if abs(grad) >= min_grad:
+                    self._update_from_history(self.moves, tokens, hist, grad)
         _cleanup(pred)
         _cleanup(gold)
         return pred.loss
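With the model-specific branches gone, the loop above concatenates the predicted-beam and gold-beam (gradient, history) pairs, drops gradients below an iteration-dependent threshold, and hands the rest to whichever model object is attached. A minimal stand-alone sketch of that control flow, with hypothetical model and violn stand-ins (in Python 3, zip() returns an iterator, so each zip is wrapped in list before concatenation, unlike the Python 2-era line in the diff):

import math

# Illustrative sketch of the unified update dispatch; not the real method.
def update_from_violation(model, moves, tokens, violn, itn):
    min_grad = 0.1 ** (itn + 1)  # prune ever-smaller gradients as training goes on
    histories = (list(zip(violn.p_probs, violn.p_hist)) +
                 list(zip(violn.g_probs, violn.g_hist)))
    for grad, hist in histories:
        assert not math.isnan(grad) and not math.isinf(grad)
        if abs(grad) >= min_grad:
            # Same call for ParserPerceptron and ParserNeuralNet: the model
            # type no longer matters to the beam parser.
            model._update_from_history(moves, tokens, hist, grad)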
@@ -149,16 +128,11 @@ cdef class BeamParser(Parser):
                                      nr_feat=self.model.nr_feat, widths=self.model.widths)
         cdef ExampleC* eg = py_eg.c
 
-        cdef ParserNeuralNet nn_model
-        cdef ParserPerceptron ap_model
         for i in range(beam.size):
             py_eg.reset()
             stcls = <StateClass>beam.at(i)
             if not stcls.c.is_final():
-                if isinstance(self.model, ParserNeuralNet):
-                    ParserNeuralNet.set_featuresC(self.model, eg, stcls.c)
-                else:
-                    ParserPerceptron.set_featuresC(self.model, eg, stcls.c)
+                self.model.set_featuresC(eg, stcls.c)
                 self.model.set_scoresC(beam.scores[i], eg.features, eg.nr_feat, 1)
                 self.moves.set_valid(beam.is_valid[i], stcls.c)
         if gold is not None:
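The one-line replacement above works because both ParserPerceptron and ParserNeuralNet now expose set_featuresC with a compatible signature, so the beam parser can rely on the shared interface rather than isinstance checks. A toy Python illustration of that dispatch pattern (the classes here are illustrative only, not the real Cython implementations):

# Toy example: two models sharing one method name and signature.
class PerceptronModel:
    def set_features(self, eg, state):
        eg["features"] = ("sparse", state)

class NeuralModel:
    def set_features(self, eg, state):
        eg["features"] = ("dense", state)

def score_state(model, eg, state):
    # Works for either model; no isinstance() branching required.
    model.set_features(eg, state)
    return eg["features"]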
@@ -173,59 +147,6 @@ cdef class BeamParser(Parser):
         beam.advance(_transition_state, _hash_state, <void*>self.moves.c)
         beam.check_done(_check_final_state, NULL)
 
-    def _update_dense(self, Doc doc, history, weight_t loss):
-        cdef Example py_eg = Example(nr_class=self.moves.n_moves, nr_atom=CONTEXT_SIZE,
-                                     nr_feat=self.model.nr_feat, widths=self.model.widths)
-        cdef ExampleC* eg = py_eg.c
-        cdef ParserNeuralNet model = self.model
-        stcls = StateClass.init(doc.c, doc.length)
-        self.moves.initialize_state(stcls.c)
-        cdef uint64_t[2] key
-        key[0] = hash64(doc.c, sizeof(TokenC) * doc.length, 0)
-        key[1] = 0
-        cdef uint64_t clas
-        for clas in history:
-            model.set_featuresC(eg, stcls.c)
-            self.moves.set_valid(eg.is_valid, stcls.c)
-            # Update with a sparse gradient: everything's 0, except our class.
-            # Remember, this is a component of the global update. It's not our
-            # "job" here to think about the other beam candidates. We just want
-            # to work on this sequence. However, other beam candidates will
-            # have gradients that refer to the same state.
-            # We therefore have a key that indicates the current sequence, so that
-            # the model can merge updates that refer to the same state together,
-            # by summing their gradients.
-            memset(eg.costs, 0, self.moves.n_moves)
-            eg.costs[clas] = loss
-            model.updateC(
-                eg.features, eg.nr_feat, True, eg.costs, eg.is_valid, False, key=key[0])
-            self.moves.c[clas].do(stcls.c, self.moves.c[clas].label)
-            py_eg.reset()
-            # Build a hash of the state sequence.
-            # Position 0 represents the previous sequence, position 1 the new class.
-            # So we want to do:
-            # key.prev = hash((key.prev, key.new))
-            # key.new = clas
-            key[1] = clas
-            key[0] = hash64(key, sizeof(key), 0)
-
-    def _update(self, Doc tokens, list hist, weight_t inc):
-        cdef Pool mem = Pool()
-        cdef atom_t[CONTEXT_SIZE] context
-        features = <FeatureC*>mem.alloc(self.model.nr_feat, sizeof(FeatureC))
-
-        cdef StateClass stcls = StateClass.init(tokens.c, tokens.length)
-        self.moves.initialize_state(stcls.c)
-
-        cdef class_t clas
-        cdef ParserPerceptron model = self.model
-        for clas in hist:
-            fill_context(context, stcls.c)
-            nr_feat = model.extracter.set_features(features, context)
-            for feat in features[:nr_feat]:
-                model.update_weight(feat.key, clas, feat.value * inc)
-            self.moves.c[clas].do(stcls.c, self.moves.c[clas].label)
-
 
 # These are passed as callbacks to thinc.search.Beam
 cdef int _transition_state(void* _dest, void* _src, class_t clas, void* _moves) except -1:
@@ -261,32 +182,3 @@ def _cleanup(Beam beam):
 cdef hash_t _hash_state(void* _state, void* _) except 0:
     state = <StateClass>_state
     return state.c.hash()
-
-
-# def _early_update(self, Doc doc, Beam pred, Beam gold):
-#     # Gather the partition function --- Z --- by which we can normalize the
-#     # scores into a probability distribution. The simple idea here is that
-#     # we clip the probability of all parses outside the beam to 0.
-#     cdef long double Z = 0.0
-#     for i in range(pred.size):
-#         # Make sure we've only got negative examples here.
-#         # Otherwise, we might double-count the gold.
-#         if pred._states[i].loss > 0:
-#             Z += exp(pred._states[i].score)
-#     cdef weight_t grad
-#     if Z > 0: # If no negative examples, don't update.
-#         Z += exp(gold.score)
-#         for i, hist in enumerate(pred.histories):
-#             if pred._states[i].loss > 0:
-#                 # Update with the negative example.
-#                 # Gradient of loss is P(parse) - 0
-#                 grad = exp(pred._states[i].score) / Z
-#                 if abs(grad) >= 0.01:
-#                     self._update_dense(doc, hist, grad)
-#         # Update with the positive example.
-#         # Gradient of loss is P(parse) - 1
-#         grad = (exp(gold.score) / Z) - 1
-#         if abs(grad) >= 0.01:
-#             self._update_dense(doc, gold.histories[0], grad)
-#
-#
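The deleted _early_update draft above computed a partition function Z over the negative candidates left in the predicted beam plus the gold parse, then used P(parse) as the gradient for negatives and P(gold) - 1 for the gold sequence. A plain-Python restatement of that calculation, under the same simplifying assumption that parses outside the beam get zero probability:

from math import exp

# Illustrative restatement of the removed draft, not spaCy API.
def beam_gradients(pred_scores, pred_losses, gold_score):
    # Only candidates with nonzero loss count as negatives, so the gold parse
    # is not double-counted if it also appears in the predicted beam.
    Z = sum(exp(s) for s, loss in zip(pred_scores, pred_losses) if loss > 0)
    if Z == 0.0:  # no negative examples: nothing to update
        return [], 0.0
    Z += exp(gold_score)
    neg_grads = [exp(s) / Z if loss > 0 else 0.0
                 for s, loss in zip(pred_scores, pred_losses)]
    gold_grad = exp(gold_score) / Z - 1.0
    return neg_grads, gold_grad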