diff --git a/spacy/ner/_state.pxd b/spacy/ner/_state.pxd index 43b37d3bd..d9ec1e8f1 100644 --- a/spacy/ner/_state.pxd +++ b/spacy/ner/_state.pxd @@ -6,6 +6,7 @@ cdef int begin_entity(State* s, label) except -1 cdef int end_entity(State* s) except -1 cdef State* init_state(Pool mem, int sent_length) except NULL +cdef int copy_state(Pool mem, State* dest, State* source) except -1 cdef bint entity_is_open(State *s) except -1 diff --git a/spacy/ner/_state.pyx b/spacy/ner/_state.pyx index 7f1892371..7d42cc799 100644 --- a/spacy/ner/_state.pyx +++ b/spacy/ner/_state.pyx @@ -1,44 +1,50 @@ -from .bilou_moves cimport BEGIN, UNIT - - -cdef int begin_entity(State* s, label) except -1: - s.curr.start = s.i - s.curr.label = label - - -cdef int end_entity(State* s) except -1: - s.curr.end = s.i - s.ents[s.j] = s.curr +cdef void begin_entity(State* s, label): s.j += 1 - s.curr.start = 0 - s.curr.label = -1 - s.curr.end = 0 + s.ents[s.j].start = s.i + s.ents[s.j].tag = label + s.ents[s.j].end = s.i + 1 + + +cdef void end_entity(State* s): + s.ents[s.j].end = s.i + 1 cdef State* init_state(Pool mem, int sent_length) except NULL: s = mem.alloc(1, sizeof(State)) - s.j = 0 s.ents = mem.alloc(sent_length, sizeof(Entity)) - for i in range(sent_length): - s.ents[i].label = -1 - s.curr.label = -1 s.tags = mem.alloc(sent_length, sizeof(int)) s.length = sent_length - return s -cdef bint entity_is_open(State *s) except -1: - return s.curr.label != -1 +cdef bint entity_is_open(State *s): + return s.ents[s.j].start != 0 -cdef bint entity_is_sunk(State *s, Move* golds) except -1: +cdef bint entity_is_sunk(State *s, Move* golds): if not entity_is_open(s): return False - cdef Move* gold = &golds[s.curr.start] + cdef Entity* ent = &s.ents[s.j] + cdef Move* gold = &golds[ent.start] if gold.action != BEGIN and gold.action != UNIT: return True - elif gold.label != s.curr.label: + elif gold.label != ent.label: return True else: return False + + +cdef int copy_state(Pool mem, State* dest, State* source) except -1: + '''Copy state source into state dest.''' + if source.length > dest.length: + dest.ents = mem.realloc(dest.ents, source.length * sizeof(Entity)) + dest.tags = mem.realloc(dest.tags, source.length * sizeof(int)) + memcpy(dest.ents, source.ents, source.length * sizeof(Entity)) + memcpy(dest.tags, source.tags, source.length * sizeof(int)) + dest.length = source.length + dest.i = source.i + dest.j = source.j + dest.curr = source.curr + + + diff --git a/spacy/ner/greedy_parser.pyx b/spacy/ner/greedy_parser.pyx index 5825c7539..2e3af5717 100644 --- a/spacy/ner/greedy_parser.pyx +++ b/spacy/ner/greedy_parser.pyx @@ -1,6 +1,3 @@ -from __future__ import division -from __future__ import unicode_literals - cimport cython import random import os @@ -10,56 +7,14 @@ import json from thinc.features cimport ConjFeat -from .context cimport fill_context -from .context cimport N_FIELDS -from .structs cimport Move, State -from .io_moves cimport fill_moves, transition, best_accepted -from .io_moves cimport set_accept_if_valid, set_accept_if_oracle -from .io_moves import get_n_moves +from ..context cimport fill_context +from ..context cimport N_FIELDS +from .moves cimport Move +from .moves cimport fill_moves, transition, best_accepted +from .moves cimport set_accept_if_valid, set_accept_if_oracle +from .moves import get_n_moves +from ._state cimport State from ._state cimport init_state -from ._state cimport entity_is_open -from ._state cimport end_entity -from .annot cimport NERAnnotation - - -def setup_model_dir(entity_types, templates, model_dir): - if path.exists(model_dir): - shutil.rmtree(model_dir) - os.mkdir(model_dir) - config = { - 'templates': templates, - 'entity_types': entity_types, - } - with open(path.join(model_dir, 'config.json'), 'w') as file_: - json.dump(config, file_) - - -def train(train_sents, model_dir, nr_iter=10): - cdef Tokens tokens - cdef NERAnnotation gold_ner - parser = NERParser(model_dir) - for _ in range(nr_iter): - tp = 0 - fp = 0 - fn = 0 - for i, (tokens, gold_ner) in enumerate(train_sents): - #print [tokens[i].string for i in range(tokens.length)] - test_ents = set(parser.train(tokens, gold_ner)) - #print 'Test', test_ents - gold_ents = set(gold_ner.entities) - #print 'Gold', set(gold_ner.entities) - tp += len(gold_ents.intersection(test_ents)) - fp += len(test_ents - gold_ents) - fn += len(gold_ents - test_ents) - p = tp / (tp + fp) - r = tp / (tp + fn) - f = 2 * ((p * r) / (p + r)) - print 'P: %.3f' % p, - print 'R: %.3f' % r, - print 'F: %.3f' % f - random.shuffle(train_sents) - parser.model.end_training() - parser.model.dump(path.join(model_dir, 'model')) cdef class NERParser: @@ -67,12 +22,12 @@ cdef class NERParser: self.mem = Pool() cfg = json.load(open(path.join(model_dir, 'config.json'))) templates = cfg['templates'] - self.extractor = Extractor(templates, [ConjFeat] * len(templates)) self.entity_types = cfg['entity_types'] + self.extractor = Extractor(templates, [ConjFeat] * len(templates)) self.n_classes = get_n_moves(len(self.entity_types)) self._moves = self.mem.alloc(self.n_classes, sizeof(Move)) - fill_moves(self._moves, self.n_classes, self.entity_types) - self.model = LinearModel(self.n_classes) + fill_moves(self._moves, len(self.entity_types)) + self.model = LinearModel(len(self.tag_names)) if path.exists(path.join(model_dir, 'model')): self.model.load(path.join(model_dir, 'model')) @@ -81,59 +36,46 @@ cdef class NERParser: self._values = self.mem.alloc(self.extractor.n+1, sizeof(weight_t)) self._scores = self.mem.alloc(self.model.nr_class, sizeof(weight_t)) - cpdef list train(self, Tokens tokens, NERAnnotation annot): + cpdef int train(self, Tokens tokens, gold_classes): cdef Pool mem = Pool() cdef State* s = init_state(mem, tokens.length) + cdef Move* golds = mem.alloc(len(gold_classes), sizeof(Move)) + for i, clas in enumerate(gold_classes): + golds[i] = self.moves[clas - 1] + assert golds[i].id == clas cdef Move* guess - cdef Move* oracle_move - n_correct = 0 - cdef int f = 0 while s.i < tokens.length: - fill_context(self._context, s, tokens) + fill_context(self._context, s.i, tokens) self.extractor.extract(self._feats, self._values, self._context, NULL) self.model.score(self._scores, self._feats, self._values) set_accept_if_valid(self._moves, self.n_classes, s) guess = best_accepted(self._moves, self._scores, self.n_classes) - assert guess.clas != 0 - set_accept_if_oracle(self._moves, self.n_classes, s, - annot.starts, annot.ends, annot.labels) - oracle_move = best_accepted(self._moves, self._scores, self.n_classes) - assert oracle_move.clas != 0 - if guess.clas == oracle_move.clas: - counts = {} - n_correct += 1 - else: - counts = {guess.clas: {}, oracle_move.clas: {}} - self.extractor.count(counts[oracle_move.clas], self._feats, 1) - self.extractor.count(counts[guess.clas], self._feats, -1) + + set_accept_if_oracle(self._moves, golds, self.n_classes, s) # TODO + gold = best_accepted(self._moves, self._scores, self.n_classes) + + if guess.clas == gold.clas: + self.model.update({}) + return 0 + + counts = {guess.clas: {}, gold.clas: {}} + self.extractor.count(counts[gold.clas], self._feats, 1) + self.extractor.count(counts[guess.clas], self._feats, -1) self.model.update(counts) + transition(s, guess) tokens.ner[s.i-1] = s.tags[s.i-1] - if entity_is_open(s): - s.curr.label = annot.labels[s.curr.start] - end_entity(s) - entities = [] - for i in range(s.j): - entities.append((s.ents[i].start, s.ents[i].end, s.ents[i].label)) - return entities - cpdef list set_tags(self, Tokens tokens): + cpdef int set_tags(self, Tokens tokens) except -1: cdef Pool mem = Pool() cdef State* s = init_state(mem, tokens.length) cdef Move* move while s.i < tokens.length: - fill_context(self._context, s, tokens) + fill_context(self._context, s.i, tokens) self.extractor.extract(self._feats, self._values, self._context, NULL) self.model.score(self._scores, self._feats, self._values) set_accept_if_valid(self._moves, self.n_classes, s) move = best_accepted(self._moves, self._scores, self.n_classes) transition(s, move) tokens.ner[s.i-1] = s.tags[s.i-1] - if entity_is_open(s): - s.curr.label = move.label - end_entity(s) - entities = [] - for i in range(s.j): - entities.append((s.ents[i].start, s.ents[i].end, s.ents[i].label)) - return entities diff --git a/spacy/ner/io_moves.pyx b/spacy/ner/io_moves.pyx index dc268e4a5..6e892ddf5 100644 --- a/spacy/ner/io_moves.pyx +++ b/spacy/ner/io_moves.pyx @@ -77,11 +77,16 @@ cdef int set_accept_if_valid(Move* moves, int n, State* s) except 0: moves[0].accept = False for i in range(1, n): if moves[i].action == SHIFT: - moves[i].accept = moves[i].label == s.curr.label or not entity_is_open(s) + if s.i >= s.length: + moves[i].accept = False + elif open_ent and moves[i].label != s.curr.label: + moves[i].accept = False + else: + moves[i].accept = True elif moves[i].action == REDUCE: moves[i].accept = open_ent elif moves[i].action == OUT: - moves[i].accept = not open_ent + moves[i].accept = s.i < s.length and not open_ent n_accept += moves[i].accept return n_accept @@ -150,3 +155,7 @@ cdef int fill_moves(Move* moves, int n, list entity_types) except -1: moves[i].clas = i moves[i].label = 0 i += 1 + + +cdef bint is_final(State* s): + return s.i == s.length and not entity_is_open(s)