* Add count_tags functionto pos.pyx, which should probably live in another file. Feature set achieves 97.9 on wsj19-21, 95.85 on onto web.

This commit is contained in:
Matthew Honnibal 2014-10-31 17:42:04 +11:00
parent 63114820cf
commit f67cb9a5a3

View File

@ -6,15 +6,17 @@ import ujson
import random
import codecs
import gzip
import cython
from libc.stdint cimport uint32_t
from thinc.weights cimport arg_max
from thinc.features import NonZeroConjFeat
from thinc.features import ConjFeat
from .en import EN
from .lexeme cimport *
from .lang cimport Lexicon
NULL_TAG = 0
@ -39,11 +41,9 @@ cdef class Tagger:
self._guess = NULL_TAG
cpdef class_t predict(self, int i, Tokens tokens, class_t prev, class_t prev_prev) except 0:
assert i >= 0
get_atoms(self._atoms, tokens.lex[i-2], tokens.lex[i-1], tokens.lex[i],
tokens.lex[i+1], tokens.lex[i+2], prev, prev_prev)
self.extractor.extract(self._feats, self._values, self._atoms, NULL)
assert self._feats[self.extractor.n] == 0
self._guess = self.model.score(self._scores, self._feats, self._values)
return self._guess
@ -64,6 +64,21 @@ cdef class Tagger:
return cls.tags[tag]
@cython.boundscheck(False)
def count_tags(Tagger tagger, Tokens tokens, uint32_t[:, :] tag_counts):
cdef class_t prev_prev, prev, tag
prev = tagger.tags['EOL']; prev_prev = tagger.tags['EOL']
cdef int i
cdef id_t token
for i in range(tokens.length):
tag = tagger.predict(i, tokens, prev, prev_prev)
prev_prev = prev
prev = tag
token = tokens.lex[i].id
if token < tag_counts.shape[0]:
tag_counts[token, tag] += 1
cpdef enum:
P2i
P2c
@ -73,6 +88,7 @@ cpdef enum:
P2suff
P2oft_title
P2oft_upper
P2pos
P1i
P1c
@ -82,6 +98,7 @@ cpdef enum:
P1suff
P1oft_title
P1oft_upper
P1pos
N0i
N0c
@ -91,6 +108,7 @@ cpdef enum:
N0suff
N0oft_title
N0oft_upper
N0pos
N1i
N1c
@ -100,6 +118,7 @@ cpdef enum:
N1suff
N1oft_title
N1oft_upper
N1pos
N2i
N2c
@ -109,6 +128,7 @@ cpdef enum:
N2suff
N2oft_title
N2oft_upper
N2pos
P2t
P1t
@ -137,6 +157,7 @@ cdef inline void _fill_token(atom_t* atoms, Lexeme* lex) nogil:
atoms[6] = lex.flags & (1 << OFT_TITLE)
atoms[7] = lex.flags & (1 << OFT_UPPER)
atoms[8] = lex.postype
TEMPLATES = (
@ -163,4 +184,15 @@ TEMPLATES = (
(P2c,),
(N0oft_upper,),
(N0oft_title,),
(P1t, N1w),
(P1t, P2t, N1w),
(P1w, P2w, N1w),
(P2w, N1w, N2w),
(N0pos,),
(N0w, N1pos),
(N0w, N1pos, N2pos),
(P1t, N0pos),
(P2t, P1t, N0pos)
)