* Work on implementing a trainable cache for the parser. So far, doesn't improve efficiency

This commit is contained in:
Matthew Honnibal 2014-12-19 09:30:50 +11:00
parent 033d6c9ac2
commit 53b8bc1f3c
4 changed files with 53 additions and 6 deletions

View File

@ -31,6 +31,7 @@ cdef class DecisionMemory:
cdef Pool mem cdef Pool mem
cdef PreshCounter _counts cdef PreshCounter _counts
cdef PreshCounter _class_counts cdef PreshCounter _class_counts
cdef PreshMap memos
cdef list class_names cdef list class_names
cdef int inc(self, hash_t context_key, hash_t clas, count_t inc) except -1 cdef int inc(self, hash_t context_key, hash_t clas, count_t inc) except -1

View File

@ -1,6 +1,6 @@
"""Create a term-document matrix""" """Create a term-document matrix"""
cimport cython cimport cython
from libc.stdint cimport int64_t
from libc.string cimport memmove from libc.string cimport memmove
from cymem.cymem cimport Address from cymem.cymem cimport Address
@ -42,6 +42,20 @@ cdef class DecisionMemory:
self.mem = Pool() self.mem = Pool()
self._counts = PreshCounter() self._counts = PreshCounter()
self._class_counts = PreshCounter() self._class_counts = PreshCounter()
self.memos = PreshMap()
def load(self, loc, thresh=0):
cdef:
count_t freq
hash_t key
int clas
for line in open(loc):
freq, key, clas = [int(p) for p in line.split()]
if thresh == 0 or freq >= thresh:
self.memos.set(key, <void*>(clas+1))
def get(self, hash_t context_key):
return <int64_t>self.memos.get(context_key) - 1
def __getitem__(self, ids): def __getitem__(self, ids):
cdef id_t[2] context cdef id_t[2] context

View File

@ -1,9 +1,12 @@
from libc.stdint cimport uint32_t, uint64_t
from thinc.features cimport Extractor from thinc.features cimport Extractor
from thinc.learner cimport LinearModel from thinc.learner cimport LinearModel
from .arc_eager cimport TransitionSystem from .arc_eager cimport TransitionSystem
from ..tokens cimport Tokens, TokenC from ..tokens cimport Tokens, TokenC
from ._state cimport State
from ..index cimport DecisionMemory
cdef class GreedyParser: cdef class GreedyParser:
@ -11,5 +14,6 @@ cdef class GreedyParser:
cdef Extractor extractor cdef Extractor extractor
cdef readonly LinearModel model cdef readonly LinearModel model
cdef TransitionSystem moves cdef TransitionSystem moves
cdef readonly DecisionMemory guess_cache
cpdef int parse(self, Tokens tokens) except -1 cpdef int parse(self, Tokens tokens) except -1

View File

@ -4,6 +4,7 @@ MALT-style dependency parser
""" """
from __future__ import unicode_literals from __future__ import unicode_literals
cimport cython cimport cython
from libc.stdint cimport uint32_t, uint64_t
import random import random
import os.path import os.path
from os.path import join as pjoin from os.path import join as pjoin
@ -11,6 +12,7 @@ import shutil
import json import json
from cymem.cymem cimport Pool, Address from cymem.cymem cimport Pool, Address
from murmurhash.mrmr cimport hash64
from thinc.typedefs cimport weight_t, class_t, feat_t, atom_t from thinc.typedefs cimport weight_t, class_t, feat_t, atom_t
@ -26,7 +28,7 @@ from ..tokens cimport Tokens, TokenC
from .arc_eager cimport TransitionSystem, Transition from .arc_eager cimport TransitionSystem, Transition
from ._state cimport init_state, State, is_final, get_idx, get_s0, get_s1 from ._state cimport init_state, State, is_final, get_idx, get_s0, get_s1, get_n0, get_n1
from . import _parse_features from . import _parse_features
from ._parse_features cimport fill_context, CONTEXT_SIZE from ._parse_features cimport fill_context, CONTEXT_SIZE
@ -63,24 +65,39 @@ cdef class GreedyParser:
self.extractor = Extractor(get_templates(self.cfg.features)) self.extractor = Extractor(get_templates(self.cfg.features))
self.moves = TransitionSystem(self.cfg.left_labels, self.cfg.right_labels) self.moves = TransitionSystem(self.cfg.left_labels, self.cfg.right_labels)
self.model = LinearModel(self.moves.n_moves, self.extractor.n_templ) self.model = LinearModel(self.moves.n_moves, self.extractor.n_templ)
# Classes for decision memory
classes = ['S', 'D']
classes += ['L-%s' % label for label in self.cfg.left_labels]
classes += ['R-%s' % label for label in self.cfg.right_labels]
self.guess_cache = DecisionMemory(classes)
if os.path.exists(pjoin(model_dir, 'model')): if os.path.exists(pjoin(model_dir, 'model')):
self.model.load(pjoin(model_dir, 'model')) self.model.load(pjoin(model_dir, 'model'))
if os.path.exists(pjoin(model_dir, 'guess_cache')):
self.guess_cache.load(pjoin(model_dir, 'guess_cache'))
cpdef int parse(self, Tokens tokens) except -1: cpdef int parse(self, Tokens tokens) except -1:
cdef: cdef:
const Feature* feats const Feature* feats
const weight_t* scores const weight_t* scores
Transition guess Transition guess
uint64_t state_key
cdef atom_t[CONTEXT_SIZE] context cdef atom_t[CONTEXT_SIZE] context
cdef int n_feats cdef int n_feats
cdef Pool mem = Pool() cdef Pool mem = Pool()
cdef State* state = init_state(mem, tokens.data, tokens.length) cdef State* state = init_state(mem, tokens.data, tokens.length)
cdef int guess_clas
while not is_final(state): while not is_final(state):
fill_context(context, state) state_key = _approx_hash_state(state)
feats = self.extractor.get_feats(context, &n_feats) guess_clas = self.guess_cache.get(state_key)
scores = self.model.get_scores(feats, n_feats) if guess_clas == -1:
guess = self.moves.best_valid(scores, state) fill_context(context, state)
feats = self.extractor.get_feats(context, &n_feats)
scores = self.model.get_scores(feats, n_feats)
guess = self.moves.best_valid(scores, state)
self.guess_cache.inc(state_key, guess.clas, 1)
else:
guess = self.moves._moves[guess_clas]
self.moves.transition(state, &guess) self.moves.transition(state, &guess)
return 0 return 0
@ -117,6 +134,17 @@ cdef class GreedyParser:
return n_corr return n_corr
cdef uint64_t _approx_hash_state(const State* state) except 0:
cdef int[3] context
context[0] = get_s0(state).lex.sic
context[1] = get_n0(state).lex.sic
if get_n1(state):
context[2] = get_n1(state).pos
else:
context[2] = 0
return hash64(context, sizeof(int) * 3, 0)
cdef dict _get_counts(int guess, int best, const Feature* feats, const int n_feats, cdef dict _get_counts(int guess, int best, const Feature* feats, const int n_feats,
int inc): int inc):
if guess == best: if guess == best: