diff --git a/spacy/_ml.pxd b/spacy/_ml.pxd index b670b0f11..95afac99c 100644 --- a/spacy/_ml.pxd +++ b/spacy/_ml.pxd @@ -12,25 +12,34 @@ from .typedefs cimport hash_t, id_t from .tokens cimport Tokens +cdef int arg_max(const weight_t* scores, const int n_classes) nogil + + cdef class Model: - cdef weight_t* score(self, atom_t* context) except NULL - cdef class_t predict(self, atom_t* context) except * - cdef class_t predict_among(self, atom_t* context, bint* valid) except * - cdef class_t predict_and_update(self, atom_t* context, const bint* valid, - const int* costs) except * - + cdef Pool mem + cdef int n_classes + + cdef int update(self, atom_t* context, class_t guess, class_t gold, int cost) except -1 + cdef object model_loc cdef Extractor _extractor cdef LinearModel _model + cdef inline const weight_t* score(self, atom_t* context): + cdef int n_feats + feats = self._extractor.get_feats(context, &n_feats) + return self._model.get_scores(feats, n_feats) + cdef class HastyModel: - cdef class_t predict(self, atom_t* context) except * - cdef class_t predict_among(self, atom_t* context, bint* valid) except * - cdef class_t predict_and_update(self, atom_t* context, const bint* valid, - const int* costs) except * + cdef Pool mem + cdef weight_t* _scores - cdef weight_t confidence + cdef const weight_t* score(self, atom_t* context) except NULL + cdef int update(self, atom_t* context, class_t guess, class_t gold, int cost) except -1 + cdef int n_classes cdef Model _hasty cdef Model _full + cdef readonly int hasty_cnt + cdef readonly int full_cnt diff --git a/spacy/_ml.pyx b/spacy/_ml.pyx index 9f26af003..7ad4bee37 100644 --- a/spacy/_ml.pyx +++ b/spacy/_ml.pyx @@ -4,7 +4,6 @@ from __future__ import division from os import path import os -from collections import defaultdict import shutil import random import json @@ -13,80 +12,39 @@ import cython from thinc.features cimport Feature, count_feats -def setup_model_dir(tag_names, tag_map, templates, model_dir): - if path.exists(model_dir): - shutil.rmtree(model_dir) - os.mkdir(model_dir) - config = { - 'templates': templates, - 'tag_names': tag_names, - 'tag_map': tag_map - } - with open(path.join(model_dir, 'config.json'), 'w') as file_: - json.dump(config, file_) +cdef int arg_max(const weight_t* scores, const int n_classes) nogil: + cdef int i + cdef int best = 0 + cdef weight_t mode = scores[0] + for i in range(1, n_classes): + if scores[i] > mode: + mode = scores[i] + best = i + return best cdef class Model: def __init__(self, n_classes, templates, model_loc=None): + if model_loc is not None and path.isdir(model_loc): + model_loc = path.join(model_loc, 'model') + self.mem = Pool() + self.n_classes = n_classes self._extractor = Extractor(templates) self._model = LinearModel(n_classes, self._extractor.n_templ) self.model_loc = model_loc if self.model_loc and path.exists(self.model_loc): self._model.load(self.model_loc, freq_thresh=0) - cdef const weight_t* score(self, atom_t* context) except NULL: + cdef int update(self, atom_t* context, class_t guess, class_t gold, int cost) except -1: cdef int n_feats - cdef const Feature* feats = self._extractor.get_feats(context, &n_feats) - return self._model.get_scores(feats, n_feats) - - cdef class_t predict(self, atom_t* context) except *: - cdef weight_t _ - scores = self.score(context) - guess = _arg_max(scores, self._model.nr_class, &_) - return guess - - cdef class_t predict_among(self, atom_t* context, const bint* valid) except *: - cdef weight_t _ - scores = self.score(context) - return _arg_max_among(scores, valid, self._model.nr_class, &_) - - cdef class_t predict_and_update(self, atom_t* context, const bint* valid, - const int* costs) except *: - cdef: - int n_feats - const Feature* feats - const weight_t* scores - - int guess - int best - int cost - int i - weight_t score - weight_t _ - - feats = self._extractor.get_feats(context, &n_feats) - scores = self._model.get_scores(feats, n_feats) - guess = _arg_max_among(scores, valid, self._model.nr_class, &_) - cost = costs[guess] if cost == 0: self._model.update({}) - return guess - - guess_counts = defaultdict(int) - best_counts = defaultdict(int) - for i in range(n_feats): - feat = (feats[i].i, feats[i].key) - upd = feats[i].value * cost - best_counts[feat] += upd - guess_counts[feat] -= upd - best = -1 - score = 0 - for i in range(self._model.nr_class): - if valid[i] and costs[i] == 0 and (best == -1 or scores[i] > score): - best = i - score = scores[i] - self._model.update({guess: guess_counts, best: best_counts}) - return guess + else: + feats = self._extractor.get_feats(context, &n_feats) + counts = {gold: {}, guess: {}} + count_feats(counts[gold], feats, n_feats, cost) + count_feats(counts[guess], feats, n_feats, -cost) + self._model.update(counts) def end_training(self): self._model.end_training() @@ -94,41 +52,34 @@ cdef class Model: cdef class HastyModel: - def __init__(self, n_classes, hasty_templates, full_templates, model_dir, - weight_t confidence=0.1): + def __init__(self, n_classes, hasty_templates, full_templates, model_dir): + full_templates = tuple([t for t in full_templates if t not in hasty_templates]) + self.mem = Pool() self.n_classes = n_classes - self.confidence = confidence + self._scores = self.mem.alloc(self.n_classes, sizeof(weight_t)) + assert path.exists(model_dir) + assert path.isdir(model_dir) self._hasty = Model(n_classes, hasty_templates, path.join(model_dir, 'hasty_model')) self._full = Model(n_classes, full_templates, path.join(model_dir, 'full_model')) + self.hasty_cnt = 0 + self.full_cnt = 0 - cdef class_t predict(self, atom_t* context) except *: - cdef weight_t ratio - scores = self._hasty.score(context) - guess = _arg_max(scores, self.n_classes, &ratio) - if ratio < self.confidence: - return guess + cdef const weight_t* score(self, atom_t* context) except NULL: + cdef int i + hasty_scores = self._hasty.score(context) + if will_use_hasty(hasty_scores, self._hasty.n_classes): + self.hasty_cnt += 1 + return hasty_scores else: - return self._full.predict(context) + self.full_cnt += 1 + full_scores = self._full.score(context) + for i in range(self.n_classes): + self._scores[i] = full_scores[i] + hasty_scores[i] + return self._scores - cdef class_t predict_among(self, atom_t* context, bint* valid) except *: - cdef weight_t ratio - scores = self._hasty.score(context) - guess = _arg_max_among(scores, valid, self.n_classes, &ratio) - if ratio < self.confidence: - return guess - else: - return self._full.predict(context) - - cdef class_t predict_and_update(self, atom_t* context, bint* valid, int* costs) except *: - cdef weight_t ratio - scores = self._hasty.score(context) - _arg_max_among(scores, valid, self.n_classes, &ratio) - hasty_guess = self._hasty.predict_and_update(context, valid, costs) - full_guess = self._full.predict_and_update(context, valid, costs) - if ratio < self.confidence: - return hasty_guess - else: - return full_guess + cdef int update(self, atom_t* context, class_t guess, class_t gold, int cost) except -1: + self._hasty.update(context, guess, gold, cost) + self._full.update(context, guess, gold, cost) def end_training(self): self._hasty.end_training() @@ -136,31 +87,29 @@ cdef class HastyModel: @cython.cdivision(True) -cdef int _arg_max(const weight_t* scores, int n_classes, weight_t* ratio) except -1: - cdef int best = 0 - cdef weight_t score = scores[best] +cdef bint will_use_hasty(const weight_t* scores, int n_classes) nogil: + cdef: + weight_t best_score, second_score + int best, second + + if scores[0] >= scores[1]: + best = 0 + best_score = scores[0] + second = 1 + second_score = scores[1] + else: + best = 1 + best_score = scores[1] + second = 0 + second_score = scores[0] cdef int i - ratio[0] = 0.0 - for i in range(1, n_classes): - if scores[i] >= score: - if score > 0: - ratio[0] = score / scores[i] - score = scores[i] + for i in range(2, n_classes): + if scores[i] > best_score: + second_score = best_score + second = best best = i - return best - - -@cython.cdivision(True) -cdef int _arg_max_among(const weight_t* scores, const bint* valid, int n_classes, - weight_t* ratio) except -1: - cdef int clas - cdef weight_t score = 0 - cdef int best = -1 - ratio[0] = 0 - for clas in range(n_classes): - if valid[clas] and (best == -1 or scores[clas] > score): - if score > 0: - ratio[0] = score / scores[clas] - score = scores[clas] - best = clas - return best + best_score = scores[i] + elif scores[i] > second_score: + second_score = scores[i] + second = i + return best_score > 0 and second_score < (best_score / 2) diff --git a/spacy/en/__init__.py b/spacy/en/__init__.py index 80efb9cad..df2b26b42 100644 --- a/spacy/en/__init__.py +++ b/spacy/en/__init__.py @@ -82,16 +82,13 @@ class English(object): tokens (spacy.tokens.Tokens): """ tokens = self.tokenizer.tokenize(text) - if self.tagger and tag: + if tag: self.tagger(tokens) - if self.parser and parse: + if parse: self.parser.parse(tokens) return tokens @property def tags(self): """List of part-of-speech tag names.""" - if self.tagger is None: - return [] - else: - return self.tagger.tag_names + return self.tagger.tag_names diff --git a/spacy/en/pos.pyx b/spacy/en/pos.pyx index f8ebc146a..67719ceb0 100644 --- a/spacy/en/pos.pyx +++ b/spacy/en/pos.pyx @@ -1,11 +1,13 @@ # cython: profile=True from os import path import json +import os +import shutil from libc.string cimport memset from cymem.cymem cimport Address -from thinc.typedefs cimport atom_t +from thinc.typedefs cimport atom_t, weight_t from ..typedefs cimport univ_tag_t from ..typedefs cimport NO_TAG, ADJ, ADV, ADP, CONJ, DET, NOUN, NUM, PRON, PRT, VERB @@ -14,6 +16,8 @@ from ..typedefs cimport id_t from ..structs cimport TokenC, Morphology, Lexeme from ..tokens cimport Tokens from ..morphology cimport set_morph_from_dict +from .._ml cimport arg_max + from .lemmatizer import Lemmatizer @@ -206,6 +210,19 @@ cdef struct _CachedMorph: int lemma +def setup_model_dir(tag_names, tag_map, templates, model_dir): + if path.exists(model_dir): + shutil.rmtree(model_dir) + os.mkdir(model_dir) + config = { + 'templates': templates, + 'tag_names': tag_names, + 'tag_map': tag_map + } + with open(path.join(model_dir, 'config.json'), 'w') as file_: + json.dump(config, file_) + + cdef class EnPosTagger: """A part-of-speech tagger for English""" def __init__(self, StringStore strings, data_dir): @@ -218,8 +235,8 @@ cdef class EnPosTagger: self.tag_map = cfg['tag_map'] cdef int n_tags = len(self.tag_names) + 1 - self.model = Model(n_tags, cfg['templates'], model_dir=model_dir) - + hasty_templates = ((W_sic,), (P1_pos, P2_pos), (N1_sic,)) + self.model = Model(n_tags, cfg['templates'], model_dir) self._morph_cache = PreshMapArray(n_tags) self.tags = self.mem.alloc(n_tags, sizeof(PosTag)) for i, tag in enumerate(sorted(self.tag_names)): @@ -239,30 +256,27 @@ cdef class EnPosTagger: """ cdef int i cdef atom_t[N_CONTEXT_FIELDS] context - cdef TokenC* t = tokens.data + cdef const weight_t* scores for i in range(tokens.length): - if t[i].fine_pos == 0: - fill_context(context, i, t) - t[i].fine_pos = self.model.predict(context) - self.set_morph(i, t) + if tokens.data[i].fine_pos == 0: + fill_context(context, i, tokens.data) + scores = self.model.score(context) + tokens.data[i].fine_pos = arg_max(scores, self.model.n_classes) + self.set_morph(i, tokens.data) - def train(self, Tokens tokens, py_golds): + def train(self, Tokens tokens, object golds): cdef int i cdef atom_t[N_CONTEXT_FIELDS] context - cdef Address costs_mem = Address(self.n_tags, sizeof(int)) - cdef Address valid_mem = Address(self.n_tags, sizeof(bint)) - cdef int* costs = costs_mem.ptr - cdef bint* valid = valid_mem.ptr - memset(valid, 1, sizeof(int) * self.n_tags) + cdef const weight_t* scores correct = 0 - cdef TokenC* t = tokens.data for i in range(tokens.length): - fill_context(context, i, t) - memset(costs, 1, sizeof(int) * self.n_tags) - costs[py_golds[i]] = 0 - t[i].fine_pos = self.model.predict_and_update(context, valid, costs) - self.set_morph(i, t) - correct += costs[t[i].fine_pos] == 0 + fill_context(context, i, tokens.data) + scores = self.model.score(context) + guess = arg_max(scores, self.model.n_classes) + self.model.update(context, guess, golds[i], guess != golds[i]) + tokens.data[i].fine_pos = guess + self.set_morph(i, tokens.data) + correct += guess == golds[i] return correct cdef int set_morph(self, const int i, TokenC* tokens) except -1: diff --git a/spacy/syntax/_parse_features.pyx b/spacy/syntax/_parse_features.pyx index 5c7f39ff7..bfdc94029 100644 --- a/spacy/syntax/_parse_features.pyx +++ b/spacy/syntax/_parse_features.pyx @@ -85,7 +85,6 @@ cdef int fill_context(atom_t* context, State* state) except -1: if state.stack_len >= 3: context[S2_has_head] = has_head(get_s2(state)) - unigrams = ( (S2W, S2p), (S2c6, S2p), @@ -347,6 +346,9 @@ clusters = ( ) +hasty = s0_n0 + n0_n1 + trigrams + + def pos_bigrams(): kernels = [S2w, S1w, S0w, S0lw, S0rw, N0w, N0lw, N1w] bitags = [] diff --git a/spacy/syntax/parser.pxd b/spacy/syntax/parser.pxd index 224ea490b..e4fc38969 100644 --- a/spacy/syntax/parser.pxd +++ b/spacy/syntax/parser.pxd @@ -1,5 +1,4 @@ -from thinc.features cimport Extractor -from thinc.learner cimport LinearModel +from .._ml cimport Model, HastyModel from .arc_eager cimport TransitionSystem @@ -8,8 +7,7 @@ from ..tokens cimport Tokens, TokenC cdef class GreedyParser: cdef object cfg - cdef Extractor extractor - cdef readonly LinearModel model + cdef readonly Model model cdef TransitionSystem moves cpdef int parse(self, Tokens tokens) except -1 diff --git a/spacy/syntax/parser.pyx b/spacy/syntax/parser.pyx index 9b23cea1c..755766eb3 100644 --- a/spacy/syntax/parser.pyx +++ b/spacy/syntax/parser.pyx @@ -7,7 +7,7 @@ cimport cython from libc.stdint cimport uint32_t, uint64_t import random import os.path -from os.path import join as pjoin +from os import path import shutil import json @@ -52,26 +52,23 @@ cdef unicode print_state(State* s, list words): def get_templates(name): pf = _parse_features if name == 'zhang': - return pf.arc_eager + return pf.unigrams, pf.arc_eager else: - return pf.unigrams + pf.s0_n0 + pf.s1_n0 + pf.s0_n1 + pf.n0_n1 + \ - pf.tree_shape + pf.trigrams + return pf.hasty, (pf.unigrams + pf.s0_n0 + pf.s1_n0 + pf.s0_n1 + pf.n0_n1 + \ + pf.tree_shape + pf.trigrams) cdef class GreedyParser: def __init__(self, model_dir): assert os.path.exists(model_dir) and os.path.isdir(model_dir) self.cfg = Config.read(model_dir, 'config') - self.extractor = Extractor(get_templates(self.cfg.features)) self.moves = TransitionSystem(self.cfg.left_labels, self.cfg.right_labels) - self.model = LinearModel(self.moves.n_moves, self.extractor.n_templ) - if os.path.exists(pjoin(model_dir, 'model')): - self.model.load(pjoin(model_dir, 'model')) + hasty_templ, full_templ = get_templates(self.cfg.features) + #self.model = HastyModel(self.moves.n_moves, hasty_templ, full_templ, model_dir) + self.model = Model(self.moves.n_moves, full_templ, model_dir) cpdef int parse(self, Tokens tokens) except -1: cdef: - const Feature* feats - const weight_t* scores Transition guess uint64_t state_key @@ -81,8 +78,7 @@ cdef class GreedyParser: cdef State* state = init_state(mem, tokens.data, tokens.length) while not is_final(state): fill_context(context, state) - feats = self.extractor.get_feats(context, &n_feats) - scores = self.model.get_scores(feats, n_feats) + scores = self.model.score(context) guess = self.moves.best_valid(scores, state) self.moves.transition(state, &guess) return 0 @@ -106,35 +102,13 @@ cdef class GreedyParser: cdef State* state = init_state(mem, tokens.data, tokens.length) while not is_final(state): - fill_context(context, state) - feats = self.extractor.get_feats(context, &n_feats) - scores = self.model.get_scores(feats, n_feats) + fill_context(context, state) + scores = self.model.score(context) guess = self.moves.best_valid(scores, state) best = self.moves.best_gold(&guess, scores, state, heads_array, labels_array) - counts = _get_counts(guess.clas, best.clas, feats, n_feats, guess.cost) - self.model.update(counts) + self.model.update(context, guess.clas, best.clas, guess.cost) self.moves.transition(state, &guess) cdef int n_corr = 0 for i in range(tokens.length): n_corr += (i + state.sent[i].head) == gold_heads[i] return n_corr - - -cdef dict _get_counts(int guess, int best, const Feature* feats, const int n_feats, - int inc): - if guess == best: - return {} - - gold_counts = {} - guess_counts = {} - cdef int i - for i in range(n_feats): - key = (feats[i].i, feats[i].key) - if key in gold_counts: - gold_counts[key] += (feats[i].value * inc) - guess_counts[key] -= (feats[i].value * inc) - else: - gold_counts[key] = (feats[i].value * inc) - guess_counts[key] = -(feats[i].value * inc) - return {guess: guess_counts, best: gold_counts} -