* Refactor _advance_beam function

This commit is contained in:
Matthew Honnibal 2015-06-02 18:38:41 +02:00
parent 0786d9b3c7
commit d1b55310a1

View File

@@ -1,9 +1,11 @@
# cython: profile=True
"""
MALT-style dependency parser
"""
from __future__ import unicode_literals
cimport cython
from libc.stdint cimport uint32_t, uint64_t
from libc.string cimport memset, memcpy
import random
import os.path
from os import path
@@ -152,11 +154,11 @@ cdef class Parser:
self._advance_beam(gold, gold_parse, True)
violn.check(pred, gold)
counts = {}
if pred._states[0].loss >= 1:
if pred.loss >= 1:
self._count_feats(counts, tokens, violn.g_hist, 1)
self._count_feats(counts, tokens, violn.p_hist, -1)
self.model._model.update(counts)
return pred._states[0].loss
return pred.loss
def _advance_beam(self, Beam beam, GoldParse gold, bint follow_gold):
cdef atom_t[CONTEXT_SIZE] context
@@ -167,22 +169,26 @@ cdef class Parser:
for i in range(beam.size):
state = <State*>beam.at(i)
fill_context(context, state)
scores = self.model.score(context)
validities = self.moves.get_valid(state)
if gold is None:
for j in range(self.moves.n_moves):
beam.set_cell(i, j, scores[j], validities[j], 0)
elif not follow_gold:
self.model.set_scores(beam.scores[i], context)
self.moves.set_valid(beam.is_valid[i], state)
if follow_gold:
for i in range(beam.size):
state = <State*>beam.at(i)
for j in range(self.moves.n_moves):
move = &self.moves.c[j]
cost = move.get_cost(move, state, gold)
beam.set_cell(i, j, scores[j], validities[j], cost)
else:
beam.costs[i][j] = move.get_cost(move, state, gold)
beam.is_valid[i][j] = beam.costs[i][j] == 0
elif gold is not None:
for i in range(beam.size):
state = <State*>beam.at(i)
for j in range(self.moves.n_moves):
move = &self.moves.c[j]
cost = move.get_cost(move, state, gold)
beam.set_cell(i, j, scores[j], cost == 0, cost)
beam.costs[i][j] = move.get_cost(move, state, gold)
beam.advance(_transition_state, <void*>self.moves.c)
state = <State*>beam.at(0)
if state.sent[state.i].sent_end:
beam.size = int(beam.size / 2)
beam.check_done(_check_final_state, NULL)
def _count_feats(self, dict counts, Tokens tokens, list hist, int inc):