* Fix POS tagger, so that it loads correctly. Lexemes are being read in.

This commit is contained in:
Matthew Honnibal 2014-10-30 13:38:55 +11:00
parent 67c8c8019f
commit 889b7b48b4
2 changed files with 10 additions and 9 deletions

View File

@ -300,6 +300,7 @@ cdef class Lexicon:
assert fp != NULL assert fp != NULL
cdef size_t st cdef size_t st
cdef Lexeme* lexeme cdef Lexeme* lexeme
i = 0
while True: while True:
lexeme = <Lexeme*>self.mem.alloc(sizeof(Lexeme), 1) lexeme = <Lexeme*>self.mem.alloc(sizeof(Lexeme), 1)
st = fread(lexeme, sizeof(Lexeme), 1, fp) st = fread(lexeme, sizeof(Lexeme), 1, fp)
@ -307,6 +308,8 @@ cdef class Lexicon:
break break
self.lexemes.push_back(lexeme) self.lexemes.push_back(lexeme)
self._dict.set(lexeme.hash, lexeme) self._dict.set(lexeme.hash, lexeme)
i += 1
print "Load %d lexemes" % i
fclose(fp) fclose(fp)

View File

@ -24,21 +24,19 @@ cdef class Tagger:
tags = {'NULL': NULL_TAG} tags = {'NULL': NULL_TAG}
def __init__(self, model_dir): def __init__(self, model_dir):
self.mem = Pool() self.mem = Pool()
self.extractor = Extractor(TEMPLATES, [ConjFeat for _ in TEMPLATES]) tags_loc = path.join(model_dir, 'postags.json')
if path.exists(tags_loc):
with open(tags_loc) as file_:
Tagger.tags.update(ujson.load(file_))
self.model = LinearModel(len(self.tags), self.extractor.n) self.model = LinearModel(len(self.tags), self.extractor.n)
if path.exists(path.join(model_dir, 'model')):
self.model.load(path.join(model_dir, 'model'))
self.extractor = Extractor(TEMPLATES, [ConjFeat for _ in TEMPLATES])
self._atoms = <atom_t*>self.mem.alloc(CONTEXT_SIZE, sizeof(atom_t)) self._atoms = <atom_t*>self.mem.alloc(CONTEXT_SIZE, sizeof(atom_t))
self._feats = <feat_t*>self.mem.alloc(self.extractor.n+1, sizeof(feat_t)) self._feats = <feat_t*>self.mem.alloc(self.extractor.n+1, sizeof(feat_t))
self._values = <weight_t*>self.mem.alloc(self.extractor.n+1, sizeof(weight_t)) self._values = <weight_t*>self.mem.alloc(self.extractor.n+1, sizeof(weight_t))
self._scores = <weight_t*>self.mem.alloc(len(self.tags), sizeof(weight_t)) self._scores = <weight_t*>self.mem.alloc(len(self.tags), sizeof(weight_t))
self._guess = NULL_TAG self._guess = NULL_TAG
if path.exists(path.join(model_dir, 'model')):
self.model.load(path.join(model_dir, 'model'))
tags_loc = path.join(model_dir, 'postags.json')
if path.exists(tags_loc):
with open(tags_loc) as file_:
Tagger.tags.update(ujson.load(file_))
if path.exists(path.join(model_dir, 'strings')):
EN.lexicon.strings.load(path.join(model_dir, 'strings'))
cpdef class_t predict(self, int i, Tokens tokens, class_t prev, class_t prev_prev) except 0: cpdef class_t predict(self, int i, Tokens tokens, class_t prev, class_t prev_prev) except 0:
assert i >= 0 assert i >= 0