* Work on implementing a trainable cache for the parser. So far, doesn't improve efficiency

This commit is contained in:
Matthew Honnibal 2014-12-19 09:30:50 +11:00
parent 033d6c9ac2
commit 53b8bc1f3c
4 changed files with 53 additions and 6 deletions

View File

@ -31,6 +31,7 @@ cdef class DecisionMemory:
cdef Pool mem
cdef PreshCounter _counts
cdef PreshCounter _class_counts
cdef PreshMap memos
cdef list class_names
cdef int inc(self, hash_t context_key, hash_t clas, count_t inc) except -1

View File

@ -1,6 +1,6 @@
"""Create a term-document matrix"""
cimport cython
from libc.stdint cimport int64_t
from libc.string cimport memmove
from cymem.cymem cimport Address
@ -42,6 +42,20 @@ cdef class DecisionMemory:
self.mem = Pool()
self._counts = PreshCounter()
self._class_counts = PreshCounter()
self.memos = PreshMap()
def load(self, loc, thresh=0):
cdef:
count_t freq
hash_t key
int clas
for line in open(loc):
freq, key, clas = [int(p) for p in line.split()]
if thresh == 0 or freq >= thresh:
self.memos.set(key, <void*>(clas+1))
def get(self, hash_t context_key):
return <int64_t>self.memos.get(context_key) - 1
def __getitem__(self, ids):
cdef id_t[2] context

View File

@ -1,9 +1,12 @@
from libc.stdint cimport uint32_t, uint64_t
from thinc.features cimport Extractor
from thinc.learner cimport LinearModel
from .arc_eager cimport TransitionSystem
from ..tokens cimport Tokens, TokenC
from ._state cimport State
from ..index cimport DecisionMemory
cdef class GreedyParser:
@ -11,5 +14,6 @@ cdef class GreedyParser:
cdef Extractor extractor
cdef readonly LinearModel model
cdef TransitionSystem moves
cdef readonly DecisionMemory guess_cache
cpdef int parse(self, Tokens tokens) except -1

View File

@ -4,6 +4,7 @@ MALT-style dependency parser
"""
from __future__ import unicode_literals
cimport cython
from libc.stdint cimport uint32_t, uint64_t
import random
import os.path
from os.path import join as pjoin
@ -11,6 +12,7 @@ import shutil
import json
from cymem.cymem cimport Pool, Address
from murmurhash.mrmr cimport hash64
from thinc.typedefs cimport weight_t, class_t, feat_t, atom_t
@ -26,7 +28,7 @@ from ..tokens cimport Tokens, TokenC
from .arc_eager cimport TransitionSystem, Transition
from ._state cimport init_state, State, is_final, get_idx, get_s0, get_s1
from ._state cimport init_state, State, is_final, get_idx, get_s0, get_s1, get_n0, get_n1
from . import _parse_features
from ._parse_features cimport fill_context, CONTEXT_SIZE
@ -63,24 +65,39 @@ cdef class GreedyParser:
self.extractor = Extractor(get_templates(self.cfg.features))
self.moves = TransitionSystem(self.cfg.left_labels, self.cfg.right_labels)
self.model = LinearModel(self.moves.n_moves, self.extractor.n_templ)
# Classes for decision memory
classes = ['S', 'D']
classes += ['L-%s' % label for label in self.cfg.left_labels]
classes += ['R-%s' % label for label in self.cfg.right_labels]
self.guess_cache = DecisionMemory(classes)
if os.path.exists(pjoin(model_dir, 'model')):
self.model.load(pjoin(model_dir, 'model'))
if os.path.exists(pjoin(model_dir, 'guess_cache')):
self.guess_cache.load(pjoin(model_dir, 'guess_cache'))
cpdef int parse(self, Tokens tokens) except -1:
cdef:
const Feature* feats
const weight_t* scores
Transition guess
uint64_t state_key
cdef atom_t[CONTEXT_SIZE] context
cdef int n_feats
cdef Pool mem = Pool()
cdef State* state = init_state(mem, tokens.data, tokens.length)
cdef int guess_clas
while not is_final(state):
fill_context(context, state)
feats = self.extractor.get_feats(context, &n_feats)
scores = self.model.get_scores(feats, n_feats)
guess = self.moves.best_valid(scores, state)
state_key = _approx_hash_state(state)
guess_clas = self.guess_cache.get(state_key)
if guess_clas == -1:
fill_context(context, state)
feats = self.extractor.get_feats(context, &n_feats)
scores = self.model.get_scores(feats, n_feats)
guess = self.moves.best_valid(scores, state)
self.guess_cache.inc(state_key, guess.clas, 1)
else:
guess = self.moves._moves[guess_clas]
self.moves.transition(state, &guess)
return 0
@ -117,6 +134,17 @@ cdef class GreedyParser:
return n_corr
cdef uint64_t _approx_hash_state(const State* state) except 0:
cdef int[3] context
context[0] = get_s0(state).lex.sic
context[1] = get_n0(state).lex.sic
if get_n1(state):
context[2] = get_n1(state).pos
else:
context[2] = 0
return hash64(context, sizeof(int) * 3, 0)
cdef dict _get_counts(int guess, int best, const Feature* feats, const int n_feats,
int inc):
if guess == best: