* Add some of the sensetagger changes

This commit is contained in:
Matthew Honnibal 2015-07-03 05:18:15 +02:00
parent b7e9c1da85
commit 1be5ab200f

View File

@ -6,7 +6,7 @@ from .structs cimport TokenC
from .strings cimport StringStore from .strings cimport StringStore
from .tokens cimport Tokens from .tokens cimport Tokens
from ._ml cimport Model from ._ml cimport Model
from .senses cimport POS_SENSES, N_SENSES from .senses cimport POS_SENSES, N_SENSES, encode_sense_strs
from .gold cimport GoldParse from .gold cimport GoldParse
@ -198,7 +198,7 @@ cdef class SenseTagger:
cdef int i cdef int i
for i in range(tokens.length): for i in range(tokens.length):
n_valid = self._set_valid(<bint*>eg.c.is_valid, pos_senses(&tokens.data[i])) n_valid = self._set_valid(<bint*>eg.c.is_valid, pos_senses(&tokens.data[i]))
if n_valid >= 2: if n_valid >= 1:
fill_context(eg.c.atoms, &tokens.data[i]) fill_context(eg.c.atoms, &tokens.data[i])
self.model.predict(eg) self.model.predict(eg)
tokens.data[i].sense = eg.c.guess tokens.data[i].sense = eg.c.guess
@ -223,16 +223,29 @@ cdef class SenseTagger:
cdef int _set_valid(self, bint* is_valid, flags_t senses) except -1: cdef int _set_valid(self, bint* is_valid, flags_t senses) except -1:
cdef int n_valid cdef int n_valid
cdef flags_t bit cdef flags_t bit
for bit in range(N_SENSES): is_valid[0] = False
for bit in range(1, N_SENSES):
is_valid[bit] = senses & (1 << bit) is_valid[bit] = senses & (1 << bit)
n_valid += is_valid[bit] n_valid += is_valid[bit]
return n_valid return n_valid
cdef int _set_costs(self, bint* is_valid, int* costs, flags_t senses): cdef int _set_costs(self, bint* is_valid, int* costs, flags_t senses):
cdef flags_t bit cdef flags_t bit
for bit in range(N_SENSES): is_valid[0] = False
costs[0] = 1
for bit in range(1, N_SENSES):
is_valid[bit] = True is_valid[bit] = True
costs[bit] = 0 if senses & (1 << bit) else 1 costs[bit] = 0 if (senses & (1 << bit)) else 1
cdef flags_t pos_senses(const TokenC* token) nogil: cdef flags_t pos_senses(const TokenC* token) nogil:
return token.lex.senses & POS_SENSES[<int>token.pos] return token.lex.senses & POS_SENSES[<int>token.pos]
cdef list _set_bits(flags_t flags):
bits = []
cdef flags_t bit
for bit in range(N_SENSES):
if flags & (1 << bit):
bits.append(bit)
return bits