* Tagger now gets 97pc on wsj, parsing 19-21 in 500ms. Gets 92.7 on web text.

This commit is contained in:
Matthew Honnibal 2014-10-22 12:57:06 +11:00
parent 0a0e41f6c8
commit ad49e2482e

View File

@ -1,9 +1,11 @@
# cython: profile=True
from os import path from os import path
import os import os
import shutil import shutil
import ujson import ujson
import random import random
import codecs import codecs
import gzip
from thinc.weights cimport arg_max from thinc.weights cimport arg_max
@ -11,9 +13,9 @@ from thinc.features import NonZeroConjFeat
from thinc.features import ConjFeat from thinc.features import ConjFeat
from .en import EN from .en import EN
from .lexeme import LexStr_shape, LexStr_suff, LexStr_pre, LexStr_norm from .lexeme cimport LexStr_shape, LexStr_suff, LexStr_pre, LexStr_norm
from .lexeme import LexDist_upper, LexDist_title from .lexeme cimport LexDist_upper, LexDist_title
from .lexeme import LexDist_upper, LexInt_cluster, LexInt_id from .lexeme cimport LexDist_upper, LexInt_cluster, LexInt_id
NULL_TAG = 0 NULL_TAG = 0
@ -31,7 +33,7 @@ cdef class Tagger:
self._scores = <weight_t*>self.mem.alloc(len(self.tags), sizeof(weight_t)) self._scores = <weight_t*>self.mem.alloc(len(self.tags), sizeof(weight_t))
self._guess = NULL_TAG self._guess = NULL_TAG
if path.exists(path.join(model_dir, 'model.gz')): if path.exists(path.join(model_dir, 'model.gz')):
with open(path.join(model_dir, 'model.gz'), 'r') as file_: with gzip.open(path.join(model_dir, 'model.gz'), 'r') as file_:
self.model.load(file_) self.model.load(file_)
cpdef class_t predict(self, int i, Tokens tokens, class_t prev, class_t prev_prev) except 0: cpdef class_t predict(self, int i, Tokens tokens, class_t prev, class_t prev_prev) except 0:
@ -58,51 +60,53 @@ cdef class Tagger:
return cls.tags[tag] return cls.tags[tag]
cpdef enum: cpdef enum:
P2i P2i
P2c
P2shape
P2suff
P2pref
P2w
P2oft_title
P2oft_upper
P1i P1i
P1c
P1shape
P1suff
P1pref
P1w
P1oft_title
P1oft_upper
N0i N0i
N0c
N0shape
N0suff
N0pref
N0w
N0oft_title
N0oft_upper
N1i N1i
N1c
N1shape
N1suff
N1pref
N1w
N1oft_title
N1oft_upper
N2i N2i
P2c
P1c
N0c
N1c
N2c N2c
P2shape
P1shape
N0shape
N1shape
N2shape N2shape
P2suff
P1suff
N0suff
N1suff
N2suff N2suff
P2pref
P1pref
N0pref
N1pref
N2pref N2pref
P2w
P1w
N0w
N1w
N2w N2w
P2oft_title
P1oft_title
N0oft_title
N1oft_title
N2oft_title N2oft_title
P2oft_upper
P1oft_upper
N0oft_upper
N1oft_upper
N2oft_upper N2oft_upper
P1t P1t
@ -115,55 +119,58 @@ cdef int get_atoms(atom_t* context, int i, Tokens tokens, class_t prev_tag,
cdef int j cdef int j
for j in range(CONTEXT_SIZE): for j in range(CONTEXT_SIZE):
context[j] = 0 context[j] = 0
indices = [i-2, i-1, i, i+1, i+2] cdef int[5] indices
ints = tokens.int_array(indices, [LexInt_id, LexInt_cluster]) indices[0] = i-2
flags = tokens.bool_array(indices, [LexDist_title, LexDist_upper]) indices[1] = i-1
strings = tokens.string_hash_array(indices, [LexStr_shape, LexStr_suff, indices[2] = i
LexStr_pre, LexStr_norm]) indices[3] = i+1
_fill_token(&context[P2i], flags[0], ints[0], strings[0]) indices[4] = i+2
_fill_token(&context[P1i], flags[1], ints[1], strings[1])
_fill_token(&context[N0i], flags[2], ints[2], strings[2]) cdef int[2] int_feats
_fill_token(&context[N1i], flags[3], ints[3], strings[3]) int_feats[0] = <int>LexInt_id
_fill_token(&context[N2i], flags[4], ints[4], strings[4]) int_feats[1] = <int>LexInt_cluster
cdef int[4] string_feats
string_feats[0] = <int>LexStr_shape
string_feats[1] = <int>LexStr_suff
string_feats[2] = <int>LexStr_pre
string_feats[3] = <int>LexStr_norm
cdef int[2] bool_feats
bool_feats[0] = <int>LexDist_title
bool_feats[1] = <int>LexDist_upper
cdef int c = 0
c = tokens.int_array(context, c, indices, 5, int_feats, 2)
c = tokens.string_array(context, c, indices, 5, string_feats, 4)
c = tokens.bool_array(context, c, indices, 5, bool_feats, 2)
context[P1t] = prev_tag context[P1t] = prev_tag
context[P2t] = prev_prev_tag context[P2t] = prev_prev_tag
cdef int _fill_token(atom_t* c, flags, ints, strings) except -1:
cdef int i = 0
c[i] = ints[0]; i += 1
c[i] = ints[1]; i += 1
c[i] = strings[0]; i += 1
c[i] = strings[1]; i += 1
c[i] = strings[2]; i += 1
c[i] = strings[3]; i += 1
c[i] = flags[0]; i += 1
c[i] = flags[1]; i += 1
TEMPLATES = ( TEMPLATES = (
(N0i,), (N0i,),
#(N0w,), (N0w,),
#(N0suff,), (N0suff,),
#(N0pref,), (N0pref,),
(P1t,), (P1t,),
(P2t,), (P2t,),
#(P1t, P2t), (P1t, P2t),
#(P1t, N0w), (P1t, N0w),
#(P1w,), (P1w,),
#(P1suff,), (P1suff,),
#(P2w,), (P2w,),
#(N1w,), (N1w,),
#(N1suff,), (N1suff,),
#(N2w,), (N2w,),
#(N0shape,), (N0shape,),
#(N0c,), (N0c,),
#(N1c,), (N1c,),
#(N2c,), (N2c,),
#(P1c,), (P1c,),
#(P2c,), (P2c,),
#(N0oft_upper,), (N0oft_upper,),
#(N0oft_title,), (N0oft_title,),
) )