Work on beam parser, with max violation

This commit is contained in:
Matthew Honnibal 2016-07-24 14:26:52 +02:00
parent a1281835a8
commit 0bf448461e

View File

@ -1,5 +1,6 @@
# cython: profile=True
# cython: experimental_cpp_class_def=True
# cython: cdivision=True
"""
MALT-style dependency parser
"""
@ -11,7 +12,7 @@ from cpython.ref cimport PyObject, Py_INCREF, Py_XDECREF
from libc.stdint cimport uint32_t, uint64_t
from libc.string cimport memset, memcpy
from libc.stdlib cimport rand
from libc.math cimport log, exp
from libc.math cimport log, exp, isnan, isinf
import random
import os.path
from os import path
@ -67,7 +68,7 @@ def get_templates(name):
cdef int BEAM_WIDTH = 8
MAX_VIOLN_UPDATE = False
cdef class BeamParser(Parser):
cdef public int beam_width
@ -91,7 +92,7 @@ cdef class BeamParser(Parser):
tokens[i] = state.c._sent[i]
_cleanup(beam)
def train(self, Doc tokens, GoldParse gold_parse):
def train(self, Doc tokens, GoldParse gold_parse, itn=0):
self.moves.preprocess_gold(gold_parse)
cdef Beam pred = Beam(self.moves.n_moves, self.beam_width)
pred.initialize(_init_state, tokens.length, tokens.c)
@ -100,6 +101,7 @@ cdef class BeamParser(Parser):
cdef Beam gold = Beam(self.moves.n_moves, self.beam_width)
gold.initialize(_init_state, tokens.length, tokens.c)
gold.check_done(_check_final_state, NULL)
violn = MaxViolation()
while not pred.is_done and not gold.is_done:
# We search separately here, to allow for ambiguity in the gold
# parse.
@ -107,38 +109,49 @@ cdef class BeamParser(Parser):
self._advance_beam(gold, gold_parse, True)
if MAX_VIOLN_UPDATE:
violn.check_crf(pred, gold)
if violn.delta >= 10000:
break
elif pred.min_score > gold.score: # Early update
break
cdef long double Z = 0.0
if MAX_VIOLN_UPDATE:
if violn.delta != -1:
for prob, hist in zip(violn.p_scores, violn.p_hist):
self._update_dense(tokens, hist, prob / violn.Z)
for prob, hist in zip(violn.g_scores, violn.g_hist):
self._update_dense(tokens, hist, -prob / violn.gZ)
self._max_violation_update(
tokens, violn.p_probs, violn.p_hist,
violn.g_probs, violn.g_hist)
else:
# Gather the partition function --- Z --- by which we can normalize the
# scores into a probability distribution. The simple idea here is that
# we clip the probability of all parses outside the beam to 0.
for i in range(pred.size):
# Make sure we've only got negative examples here.
# Otherwise, we might double-count the gold.
if pred._states[i].loss > 0:
Z += exp(pred._states[i].score)
if Z > 0: # If no negative examples, don't update.
Z += exp(gold.score)
for i, hist in enumerate(pred.histories):
if pred._states[i].loss > 0:
# Update with the negative example.
# Gradient of loss is P(parse) - 0
self._update_dense(tokens, hist, exp(pred._states[i].score) / Z)
# Update with the positive example.
# Gradient of loss is P(parse) - 1
self._update_dense(tokens, gold.histories[0], (exp(gold.score) / Z) - 1)
self._early_update(tokens, pred, gold)
_cleanup(pred)
_cleanup(gold)
return pred.loss
def _max_violation_update(self, Doc doc, p_grads, p_hist, g_grads, g_hist):
    """Apply dense updates for the predicted and the gold histories.

    Each gradient in `p_grads` is paired with the history at the same
    position in `p_hist` and passed to `self._update_dense` unchanged.
    Each gradient in `g_grads` is paired with `g_hist` in the same way,
    but is negated before being passed on.

    Gradients whose magnitude is not at least 1e-5 are skipped.
    """
    min_magnitude = 1e-5
    # Predicted histories: the gradient is applied with its own sign.
    for weight, history in zip(p_grads, p_hist):
        # Phrased as "not >=" so that a nan gradient is skipped too.
        if not (abs(weight) >= min_magnitude):
            continue
        self._update_dense(doc, history, weight)
    # Gold histories: the gradient is applied with the opposite sign.
    for weight, history in zip(g_grads, g_hist):
        if not (abs(weight) >= min_magnitude):
            continue
        self._update_dense(doc, history, -weight)
def _early_update(self, Doc doc, Beam pred, Beam gold):
    """Apply an early-update gradient step from the current beams.

    The scores of the incorrect (loss > 0) states in `pred`, together with
    the gold score, are normalised into a probability distribution. The
    simple idea is that the probability of all parses outside the beam is
    clipped to 0. Every incorrect history is then updated with gradient
    P(parse) - 0, and the first gold history with gradient P(parse) - 1.

    If `pred` contains no incorrect state, no update is made.
    """
    # The partition function --- Z --- by which the scores are normalised.
    cdef long double Z = 0.0
    cdef double gold_score = gold.score
    # Largest score taking part in the normalisation. Subtracting it before
    # exponentiating leaves every ratio exp(score) / Z mathematically
    # unchanged, but stops exp() overflowing to inf (which would turn the
    # gradients into nan) when the scores are large.
    cdef double max_score = gold_score
    cdef bint has_negative = False
    for i in range(pred.size):
        # Make sure we've only got negative examples here.
        # Otherwise, we might double-count the gold.
        if pred._states[i].loss > 0:
            has_negative = True
            if pred._states[i].score > max_score:
                max_score = pred._states[i].score
    if not has_negative:
        # If no negative examples, don't update.
        return
    Z = exp(gold_score - max_score)
    for i in range(pred.size):
        if pred._states[i].loss > 0:
            Z += exp(pred._states[i].score - max_score)
    # Z is at least 1 here, because the maximal score contributes exp(0),
    # so the divisions below cannot divide by zero.
    for i, hist in enumerate(pred.histories):
        if pred._states[i].loss > 0:
            # Update with the negative example.
            # Gradient of loss is P(parse) - 0
            self._update_dense(doc, hist, exp(pred._states[i].score - max_score) / Z)
    # Update with the positive example.
    # Gradient of loss is P(parse) - 1
    self._update_dense(doc, gold.histories[0], (exp(gold_score - max_score) / Z) - 1)
def _advance_beam(self, Beam beam, GoldParse gold, bint follow_gold):
cdef Example py_eg = Example(nr_class=self.moves.n_moves, nr_atom=CONTEXT_SIZE,
nr_feat=self.model.nr_feat, widths=self.model.widths)