Work on beam parser, with max violation

This commit is contained in:
Matthew Honnibal 2016-07-24 14:26:52 +02:00
parent a1281835a8
commit 0bf448461e

View File

@ -1,5 +1,6 @@
# cython: profile=True # cython: profile=True
# cython: experimental_cpp_class_def=True # cython: experimental_cpp_class_def=True
# cython: cdivision=True
""" """
MALT-style dependency parser MALT-style dependency parser
""" """
@ -11,7 +12,7 @@ from cpython.ref cimport PyObject, Py_INCREF, Py_XDECREF
from libc.stdint cimport uint32_t, uint64_t from libc.stdint cimport uint32_t, uint64_t
from libc.string cimport memset, memcpy from libc.string cimport memset, memcpy
from libc.stdlib cimport rand from libc.stdlib cimport rand
from libc.math cimport log, exp from libc.math cimport log, exp, isnan, isinf
import random import random
import os.path import os.path
from os import path from os import path
@ -67,7 +68,7 @@ def get_templates(name):
cdef int BEAM_WIDTH = 8 cdef int BEAM_WIDTH = 8
MAX_VIOLN_UPDATE = False
cdef class BeamParser(Parser): cdef class BeamParser(Parser):
cdef public int beam_width cdef public int beam_width
@ -91,7 +92,7 @@ cdef class BeamParser(Parser):
tokens[i] = state.c._sent[i] tokens[i] = state.c._sent[i]
_cleanup(beam) _cleanup(beam)
def train(self, Doc tokens, GoldParse gold_parse): def train(self, Doc tokens, GoldParse gold_parse, itn=0):
self.moves.preprocess_gold(gold_parse) self.moves.preprocess_gold(gold_parse)
cdef Beam pred = Beam(self.moves.n_moves, self.beam_width) cdef Beam pred = Beam(self.moves.n_moves, self.beam_width)
pred.initialize(_init_state, tokens.length, tokens.c) pred.initialize(_init_state, tokens.length, tokens.c)
@ -100,6 +101,7 @@ cdef class BeamParser(Parser):
cdef Beam gold = Beam(self.moves.n_moves, self.beam_width) cdef Beam gold = Beam(self.moves.n_moves, self.beam_width)
gold.initialize(_init_state, tokens.length, tokens.c) gold.initialize(_init_state, tokens.length, tokens.c)
gold.check_done(_check_final_state, NULL) gold.check_done(_check_final_state, NULL)
violn = MaxViolation()
while not pred.is_done and not gold.is_done: while not pred.is_done and not gold.is_done:
# We search separately here, to allow for ambiguity in the gold # We search separately here, to allow for ambiguity in the gold
# parse. # parse.
@ -107,19 +109,33 @@ cdef class BeamParser(Parser):
self._advance_beam(gold, gold_parse, True) self._advance_beam(gold, gold_parse, True)
if MAX_VIOLN_UPDATE: if MAX_VIOLN_UPDATE:
violn.check_crf(pred, gold) violn.check_crf(pred, gold)
if violn.delta >= 10000:
break
elif pred.min_score > gold.score: # Early update elif pred.min_score > gold.score: # Early update
break break
cdef long double Z = 0.0
if MAX_VIOLN_UPDATE: if MAX_VIOLN_UPDATE:
if violn.delta != -1: self._max_violation_update(
for prob, hist in zip(violn.p_scores, violn.p_hist): tokens, violn.p_probs, violn.p_hist,
self._update_dense(tokens, hist, prob / violn.Z) violn.g_probs, violn.g_hist)
for prob, hist in zip(violn.g_scores, violn.g_hist):
self._update_dense(tokens, hist, -prob / violn.gZ)
else: else:
self._early_update(tokens, pred, gold)
_cleanup(pred)
_cleanup(gold)
return pred.loss
def _max_violation_update(self, Doc doc, p_grads, p_hist, g_grads, g_hist):
for grad, hist in zip(p_grads, p_hist):
if abs(grad) >= 1e-5:
self._update_dense(doc, hist, grad)
for grad, hist in zip(g_grads, g_hist):
if abs(grad) >= 1e-5:
self._update_dense(doc, hist, -grad)
def _early_update(self, Doc doc, Beam pred, Beam gold):
# Gather the partition function --- Z --- by which we can normalize the # Gather the partition function --- Z --- by which we can normalize the
# scores into a probability distribution. The simple idea here is that # scores into a probability distribution. The simple idea here is that
# we clip the probability of all parses outside the beam to 0. # we clip the probability of all parses outside the beam to 0.
cdef long double Z = 0.0
for i in range(pred.size): for i in range(pred.size):
# Make sure we've only got negative examples here. # Make sure we've only got negative examples here.
# Otherwise, we might double-count the gold. # Otherwise, we might double-count the gold.
@ -131,13 +147,10 @@ cdef class BeamParser(Parser):
if pred._states[i].loss > 0: if pred._states[i].loss > 0:
# Update with the negative example. # Update with the negative example.
# Gradient of loss is P(parse) - 0 # Gradient of loss is P(parse) - 0
self._update_dense(tokens, hist, exp(pred._states[i].score) / Z) self._update_dense(doc, hist, exp(pred._states[i].score) / Z)
# Update with the positive example. # Update with the positive example.
# Gradient of loss is P(parse) - 1 # Gradient of loss is P(parse) - 1
self._update_dense(tokens, gold.histories[0], (exp(gold.score) / Z) - 1) self._update_dense(doc, gold.histories[0], (exp(gold.score) / Z) - 1)
_cleanup(pred)
_cleanup(gold)
return pred.loss
def _advance_beam(self, Beam beam, GoldParse gold, bint follow_gold): def _advance_beam(self, Beam beam, GoldParse gold, bint follow_gold):
cdef Example py_eg = Example(nr_class=self.moves.n_moves, nr_atom=CONTEXT_SIZE, cdef Example py_eg = Example(nr_class=self.moves.n_moves, nr_atom=CONTEXT_SIZE,