diff --git a/spacy/sense_tagger.pyx b/spacy/sense_tagger.pyx index 82ddf93e6..82e07710a 100644 --- a/spacy/sense_tagger.pyx +++ b/spacy/sense_tagger.pyx @@ -1,13 +1,17 @@ -from thinc.api cimport Example -from thinc.typedefs cimport atom_t - from .typedefs cimport flags_t from .structs cimport TokenC from .strings cimport StringStore from .tokens cimport Tokens -from ._ml cimport Model from .senses cimport POS_SENSES, N_SENSES, encode_sense_strs from .gold cimport GoldParse +from .parts_of_speech cimport NOUN, VERB + +from thinc.learner cimport LinearModel +from thinc.features cimport Extractor + +from thinc.typedefs cimport atom_t, weight_t, feat_t + +from os import path @@ -173,6 +177,8 @@ cdef int fill_token(atom_t* ctxt, const TokenC* token) except -1: cdef int fill_context(atom_t* ctxt, const TokenC* token) except -1: + # NB: we have padding to keep us safe here + # See tokens.pyx fill_token(&ctxt[P2W], token - 2) fill_token(&ctxt[P1W], token - 1) @@ -185,62 +191,79 @@ cdef int fill_context(atom_t* ctxt, const TokenC* token) except -1: cdef class SenseTagger: cdef readonly StringStore strings - cdef readonly Model model + cdef readonly LinearModel model + cdef readonly Extractor extractor + cdef readonly model_dir def __init__(self, StringStore strings, model_dir): - self.strings = strings + if model_dir is not None and path.isdir(model_dir): + model_dir = path.join(model_dir, 'model') + templates = unigrams + bigrams + trigrams - self.model = Model(N_SENSES, templates, model_dir) + self.extractor = Extractor(templates) + self.model = LinearModel(N_SENSES, self.extractor.n_templ) + self.model_dir = model_dir + if self.model_dir and path.exists(self.model_dir): + self.model.load(self.model_dir, freq_thresh=0) + self.strings = strings def __call__(self, Tokens tokens): - eg = Example(self.model.n_classes, CONTEXT_SIZE, self.model.n_feats, - self.model.n_feats) - cdef int i + cdef atom_t[CONTEXT_SIZE] context + cdef int i, guess, n_feats + cdef const TokenC* token for i in range(tokens.length): - n_valid = self._set_valid(eg.c.is_valid, pos_senses(&tokens.data[i])) - if n_valid >= 1: - fill_context(eg.c.atoms, &tokens.data[i]) - self.model.predict(eg) - tokens.data[i].sense = eg.c.guess + token = &tokens.data[i] + if token.pos in (NOUN, VERB): + fill_context(context, token) + feats = self.extractor.get_feats(context, &n_feats) + scores = self.model.get_scores(feats, n_feats) + tokens.data[i].sense = self.best_in_set(scores, POS_SENSES[token.pos]) def train(self, Tokens tokens, GoldParse gold): - eg = Example(self.model.n_classes, CONTEXT_SIZE, self.model.n_feats+1, - self.model.n_feats+1) - cdef int i + cdef int i, j for i, ssenses in enumerate(gold.ssenses): if ssenses: gold.c.ssenses[i] = encode_sense_strs(ssenses) else: gold.c.ssenses[i] = pos_senses(&tokens.data[i]) + + cdef atom_t[CONTEXT_SIZE] context + cdef int n_feats + cdef feat_t f_key + cdef int f_i cdef int cost = 0 for i in range(tokens.length): - if tokens.data[i].lex.senses == 0 or tokens.data[i].lex.senses == 1: - continue - self._set_costs(eg.c.is_valid, eg.c.costs, gold.c.ssenses[i]) - fill_context(eg.c.atoms, &tokens.data[i]) - - self.model.train(eg) - - tokens.data[i].sense = eg.c.guess - cost += eg.c.cost + token = &tokens.data[i] + if token.pos in (NOUN, VERB) \ + and token.lex.senses >= 2 \ + and gold.c.ssenses[i] >= 2: + fill_context(context, token) + feats = self.extractor.get_feats(context, &n_feats) + scores = self.model.get_scores(feats, n_feats) + token.sense = self.best_in_set(scores, POS_SENSES[token.pos]) + best = self.best_in_set(scores, gold.c.ssenses[i]) + guess_counts = {} + gold_counts = {} + if token.sense != best: + for j in range(n_feats): + f_key = feats[j].key + f_i = feats[j].i + feat = (f_i, f_key) + gold_counts[feat] = gold_counts.get(feat, 0) + 1.0 + guess_counts[feat] = guess_counts.get(feat, 0) - 1.0 + #self.model.update({token.sense: guess_counts, best: gold_counts}) return cost - cdef int _set_valid(self, bint* is_valid, flags_t senses) except -1: - cdef int n_valid - cdef flags_t bit - is_valid[0] = False - for bit in range(1, N_SENSES): - is_valid[bit] = senses & (1 << bit) - n_valid += is_valid[bit] - return n_valid - - cdef int _set_costs(self, bint* is_valid, int* costs, flags_t senses): - cdef flags_t bit - is_valid[0] = False - costs[0] = 1 - for bit in range(1, N_SENSES): - is_valid[bit] = True - costs[bit] = 0 if (senses & (1 << bit)) else 1 + cdef int best_in_set(self, const weight_t* scores, flags_t senses) except -1: + cdef weight_t max_ = 0 + cdef int argmax = -1 + cdef flags_t i + for i in range(N_SENSES): + if (senses & (1 << i)) and (argmax == -1 or scores[i] > max_): + max_ = scores[i] + argmax = i + assert argmax >= 0 + return argmax cdef flags_t pos_senses(const TokenC* token) nogil: