diff --git a/bin/parser/train.py b/bin/parser/train.py index e53a4edb4..bc1c91d70 100755 --- a/bin/parser/train.py +++ b/bin/parser/train.py @@ -24,6 +24,8 @@ from spacy.syntax.util import Config from spacy.syntax.conll import read_docparse_file from spacy.syntax.conll import GoldParse +from spacy.scorer import Scorer + def is_punct_label(label): return label == 'P' or label.lower() == 'punct' @@ -186,7 +188,6 @@ def get_labels(sents): def train(Language, train_loc, model_dir, n_iter=15, feat_set=u'basic', seed=0, gold_preproc=False, force_gold=False, n_sents=0): - print "Setup model dir" dep_model_dir = path.join(model_dir, 'deps') pos_model_dir = path.join(model_dir, 'pos') ner_model_dir = path.join(model_dir, 'ner') @@ -209,13 +210,16 @@ def train(Language, train_loc, model_dir, n_iter=15, feat_set=u'basic', seed=0, Config.write(ner_model_dir, 'config', features='ner', seed=seed, labels=Language.EntityTransitionSystem.get_labels(gold_tuples)) + if n_sents > 0: + gold_tuples = gold_tuples[:n_sents] nlp = Language() - + ent_strings = [None] * (max(nlp.entity.moves.label_ids.values()) + 1) + for label, i in nlp.entity.moves.label_ids.items(): + ent_strings[i] = label + + print "Itn.\tUAS\tNER F.\tTag %" for itn in range(n_iter): - dep_corr = 0 - pos_corr = 0 - ent_corr = 0 - n_tokens = 0 + scorer = Scorer() for raw_text, segmented_text, annot_tuples in gold_tuples: if gold_preproc: sents = [nlp.tokenizer.tokens_from_list(s) for s in segmented_text] @@ -224,51 +228,32 @@ def train(Language, train_loc, model_dir, n_iter=15, feat_set=u'basic', seed=0, for tokens in sents: gold = GoldParse(tokens, annot_tuples) nlp.tagger(tokens) - #ent_corr += nlp.entity.train(tokens, gold, force_gold=force_gold) - dep_corr += nlp.parser.train(tokens, gold, force_gold=force_gold) - pos_corr += nlp.tagger.train(tokens, gold.tags) - n_tokens += len(tokens) - acc = float(dep_corr) / n_tokens - pos_acc = float(pos_corr) / n_tokens - print '%d: ' % itn, '%.3f' % acc, '%.3f' % pos_acc + nlp.entity.train(tokens, gold, force_gold=force_gold) + #nlp.parser.train(tokens, gold, force_gold=force_gold) + nlp.tagger.train(tokens, gold.tags) + + nlp.entity(tokens) + tokens._ent_strings = tuple(ent_strings) + nlp.parser(tokens) + scorer.score(tokens, gold, verbose=False) + print '%d:\t%.3f\t%.3f\t%.3f' % (itn, scorer.uas, scorer.ents_f, scorer.tags_acc) random.shuffle(gold_tuples) nlp.parser.model.end_training() + nlp.entity.model.end_training() nlp.tagger.model.end_training() - return acc def evaluate(Language, dev_loc, model_dir, gold_preproc=False): global loss + assert not gold_preproc nlp = Language() - uas_corr = 0 - las_corr = 0 - pos_corr = 0 - n_tokens = 0 - total = 0 - skipped = 0 - loss = 0 - gold_tuples = read_docparse_file(train_loc) + gold_tuples = read_docparse_file(dev_loc) + scorer = Scorer() for raw_text, segmented_text, annot_tuples in gold_tuples: - if gold_preproc: - tokens = nlp.tokenizer.tokens_from_list(gold_sent.words) - nlp.tagger(tokens) - nlp.parser(tokens) - gold_sent.map_heads(nlp.parser.moves.label_ids) - else: - tokens = nlp(gold_sent.raw_text) - loss += gold_sent.align_to_tokens(tokens, nlp.parser.moves.label_ids) - for i, token in enumerate(tokens): - pos_corr += token.tag_ == gold_sent.tags[i] - n_tokens += 1 - if gold_sent.heads[i] is None: - skipped += 1 - continue - if gold_sent.labels[i] != 'P': - n_corr += gold_sent.is_correct(i, token.head.i) - total += 1 - print loss, skipped, (loss+skipped + total) - print pos_corr / n_tokens - return float(n_corr) / (total + loss) + tokens = nlp(raw_text) + gold = GoldParse(tokens, annot_tuples) + scorer.score(tokens, gold, verbose=False) + return scorer @@ -281,7 +266,14 @@ def evaluate(Language, dev_loc, model_dir, gold_preproc=False): def main(train_loc, dev_loc, model_dir, n_sents=0): train(English, train_loc, model_dir, gold_preproc=False, force_gold=False, n_sents=n_sents) - print evaluate(English, dev_loc, model_dir, gold_preproc=False) + scorer = evaluate(English, dev_loc, model_dir, gold_preproc=False) + print 'POS', scorer.tags_acc + print 'UAS', scorer.uas + print 'LAS', scorer.las + + print 'NER P', scorer.ents_p + print 'NER R', scorer.ents_r + print 'NER F', scorer.ents_f if __name__ == '__main__': diff --git a/spacy/en/__init__.py b/spacy/en/__init__.py index 7bc1d36fd..dec591d55 100644 --- a/spacy/en/__init__.py +++ b/spacy/en/__init__.py @@ -180,7 +180,12 @@ class English(object): if parse and self.has_parser_model: self.parser(tokens) if entity and self.has_entity_model: + # TODO: Clean this up self.entity(tokens) + ent_strings = [None] * (max(self.entity.moves.label_ids.values()) + 1) + for label, i in self.entity.moves.label_ids.items(): + ent_strings[i] = label + tokens._ent_strings = tuple(ent_strings) return tokens @property diff --git a/spacy/structs.pxd b/spacy/structs.pxd index c7e39f77a..c1fc13ecd 100644 --- a/spacy/structs.pxd +++ b/spacy/structs.pxd @@ -45,14 +45,12 @@ cdef struct PosTag: cdef struct Entity: int start int end - int tag int label - + cdef struct TokenC: const LexemeC* lex Morphology morph - Entity ent univ_pos_t pos int tag int idx @@ -64,6 +62,9 @@ cdef struct TokenC: uint32_t l_kids uint32_t r_kids + int ent_iob + int ent_type + cdef struct Utf8Str: id_t i diff --git a/spacy/syntax/_parse_features.pxd b/spacy/syntax/_parse_features.pxd index 2a2d50272..64bf5515a 100644 --- a/spacy/syntax/_parse_features.pxd +++ b/spacy/syntax/_parse_features.pxd @@ -16,6 +16,7 @@ cdef int fill_context(atom_t* context, State* state) except -1 # S0w, # S0r0w, S0r2w, S0rw, # N0l0w, N0l2w, N0lw, +# P2w, P1w, # N0w, N1w, N2w, N3w, 0 #] @@ -28,6 +29,9 @@ cpdef enum: S2c4 S2c6 S2L + S2_prefix + S2_suffix + S2_shape S1w S1W @@ -36,6 +40,9 @@ cpdef enum: S1c4 S1c6 S1L + S1_prefix + S1_suffix + S1_shape S1rw S1rW @@ -44,6 +51,9 @@ cpdef enum: S1rc4 S1rc6 S1rL + S1r_prefix + S1r_suffix + S1r_shape S0lw S0lW @@ -52,6 +62,9 @@ cpdef enum: S0lc4 S0lc6 S0lL + S0l_prefix + S0l_suffix + S0l_shape S0l2w S0l2W @@ -60,6 +73,9 @@ cpdef enum: S0l2c4 S0l2c6 S0l2L + S0l2_prefix + S0l2_suffix + S0l2_shape S0w S0W @@ -68,6 +84,9 @@ cpdef enum: S0c4 S0c6 S0L + S0_prefix + S0_suffix + S0_shape S0r2w S0r2W @@ -76,6 +95,9 @@ cpdef enum: S0r2c4 S0r2c6 S0r2L + S0r2_prefix + S0r2_suffix + S0r2_shape S0rw S0rW @@ -84,6 +106,9 @@ cpdef enum: S0rc4 S0rc6 S0rL + S0r_prefix + S0r_suffix + S0r_shape N0l2w N0l2W @@ -92,6 +117,9 @@ cpdef enum: N0l2c4 N0l2c6 N0l2L + N0l2_prefix + N0l2_suffix + N0l2_shape N0lw N0lW @@ -100,6 +128,9 @@ cpdef enum: N0lc4 N0lc6 N0lL + N0l_prefix + N0l_suffix + N0l_shape N0w N0W @@ -108,6 +139,9 @@ cpdef enum: N0c4 N0c6 N0L + N0_prefix + N0_suffix + N0_shape N1w N1W @@ -116,7 +150,10 @@ cpdef enum: N1c4 N1c6 N1L - + N1_prefix + N1_suffix + N1_shape + N2w N2W N2p @@ -124,7 +161,32 @@ cpdef enum: N2c4 N2c6 N2L + N2_prefix + N2_suffix + N2_shape + + P1w + P1W + P1p + P1c + P1c4 + P1c6 + P1L + P1_prefix + P1_suffix + P1_shape + P2w + P2W + P2p + P2c + P2c4 + P2c6 + P2L + P2_prefix + P2_suffix + P2_shape + # Misc features at the end dist N0lv diff --git a/spacy/syntax/_parse_features.pyx b/spacy/syntax/_parse_features.pyx index d54998347..a204526cc 100644 --- a/spacy/syntax/_parse_features.pyx +++ b/spacy/syntax/_parse_features.pyx @@ -12,6 +12,7 @@ from itertools import combinations from ..tokens cimport TokenC from ._state cimport State from ._state cimport get_s2, get_s1, get_s0, get_n0, get_n1, get_n2 +from ._state cimport get_p2, get_p1 from ._state cimport has_head, get_left, get_right from ._state cimport count_left_kids, count_right_kids @@ -45,6 +46,9 @@ cdef inline void fill_token(atom_t* context, const TokenC* token) nogil: context[4] = token.lex.cluster & 63 context[5] = token.lex.cluster & 15 context[6] = token.dep if has_head(token) else 0 + context[7] = token.lex.prefix + context[8] = token.lex.suffix + context[9] = token.lex.shape cdef int fill_context(atom_t* context, State* state) except -1: @@ -62,7 +66,8 @@ cdef int fill_context(atom_t* context, State* state) except -1: fill_token(&context[N0l2w], get_left(state, get_n0(state), 2)) fill_token(&context[N0w], get_n0(state)) fill_token(&context[N1w], get_n1(state)) - fill_token(&context[N2w], get_n2(state)) + fill_token(&context[P1w], get_p1(state)) + fill_token(&context[P2w], get_p2(state)) if state.stack_len >= 1: context[dist] = state.stack[0] - state.i @@ -84,6 +89,54 @@ cdef int fill_context(atom_t* context, State* state) except -1: if state.stack_len >= 3: context[S2_has_head] = has_head(get_s2(state)) +ner = ( + (N0w,), + (P1w,), + (N1w,), + (P2w,), + (N2w,), + + (P1w, N0w,), + (N0w, N1w), + + (N0_prefix,), + (N0_suffix,), + + (P1_shape,), + (N0_shape,), + (N1_shape,), + (P1_shape, N0_shape,), + (N0_shape, P1_shape,), + (P1_shape, N0_shape, N1_shape), + (N2_shape,), + (P2_shape,), + + #(P2_norm, P1_norm, W_norm), + #(P1_norm, W_norm, N1_norm), + #(W_norm, N1_norm, N2_norm) + + (P2p,), + (P1p,), + (N0p,), + (N1p,), + (N2p,), + + (P1p, N0p), + (N0p, N1p), + (P2p, P1p, N0p), + (P1p, N0p, N1p), + (N0p, N1p, N2p), + + (P2c,), + (P1c,), + (N0c,), + (N1c,), + (N2c,), + + (P1c, N0c), + (N0c, N1c), +) + unigrams = ( (S2W, S2p), diff --git a/spacy/syntax/_state.pxd b/spacy/syntax/_state.pxd index 6581e0648..9936bc33f 100644 --- a/spacy/syntax/_state.pxd +++ b/spacy/syntax/_state.pxd @@ -40,6 +40,21 @@ cdef inline TokenC* get_n1(const State* s) nogil: return &s.sent[s.i+1] +cdef inline TokenC* get_p1(const State* s) nogil: + if s.i < 1: + return NULL + else: + return &s.sent[s.i-1] + + +cdef inline TokenC* get_p2(const State* s) nogil: + if s.i < 2: + return NULL + else: + return &s.sent[s.i-2] + + + cdef inline TokenC* get_n2(const State* s) nogil: if (s.i + 2) >= s.sent_len: return NULL @@ -77,7 +92,7 @@ cdef int head_in_buffer(const State *s, const int child, const int* gold) except cdef int children_in_stack(const State *s, const int head, const int* gold) except -1 cdef int head_in_stack(const State *s, const int child, const int* gold) except -1 -cdef State* init_state(Pool mem, TokenC* sent, const int sent_length) except NULL +cdef State* new_state(Pool mem, TokenC* sent, const int sent_length) except NULL cdef int count_left_kids(const TokenC* head) nogil diff --git a/spacy/syntax/_state.pyx b/spacy/syntax/_state.pyx index 072bc648a..b2dbe3772 100644 --- a/spacy/syntax/_state.pyx +++ b/spacy/syntax/_state.pyx @@ -2,7 +2,7 @@ from libc.string cimport memmove, memcpy from cymem.cymem cimport Pool from ..lexeme cimport EMPTY_LEXEME -from ..structs cimport TokenC +from ..structs cimport TokenC, Entity DEF PADDING = 5 @@ -112,13 +112,15 @@ cdef int count_right_kids(const TokenC* head) nogil: return _popcount(head.r_kids) -cdef State* init_state(Pool mem, const TokenC* sent, const int sent_len) except NULL: +cdef State* new_state(Pool mem, const TokenC* sent, const int sent_len) except NULL: cdef int padded_len = sent_len + PADDING + PADDING cdef State* s = mem.alloc(1, sizeof(State)) + s.ent = mem.alloc(padded_len, sizeof(Entity)) s.stack = mem.alloc(padded_len, sizeof(int)) for i in range(PADDING): s.stack[i] = -1 s.stack += (PADDING - 1) + s.ent += (PADDING - 1) assert s.stack[0] == -1 state_sent = mem.alloc(padded_len, sizeof(TokenC)) memcpy(state_sent, sent - PADDING, padded_len * sizeof(TokenC)) @@ -126,5 +128,4 @@ cdef State* init_state(Pool mem, const TokenC* sent, const int sent_len) except s.stack_len = 0 s.i = 0 s.sent_len = sent_len - push_stack(s) return s diff --git a/spacy/syntax/arc_eager.pyx b/spacy/syntax/arc_eager.pyx index d757d28c1..b705086e8 100644 --- a/spacy/syntax/arc_eager.pyx +++ b/spacy/syntax/arc_eager.pyx @@ -58,7 +58,6 @@ cdef class ArcEager(TransitionSystem): gold.c_heads[i] = gold.heads[i] gold.c_labels[i] = self.label_ids[gold.labels[i]] - cdef Transition lookup_transition(self, object name) except *: if '-' in name: move_str, label_str = name.split('-', 1) @@ -82,6 +81,9 @@ cdef class ArcEager(TransitionSystem): t.get_cost = get_cost_funcs[move] return t + cdef int first_state(self, State* state) except -1: + push_stack(state) + cdef Transition best_valid(self, const weight_t* scores, const State* s) except *: cdef bint[N_MOVES] is_valid is_valid[SHIFT] = _can_shift(s) diff --git a/spacy/syntax/conll.pxd b/spacy/syntax/conll.pxd index 626cc699d..60583969b 100644 --- a/spacy/syntax/conll.pxd +++ b/spacy/syntax/conll.pxd @@ -14,6 +14,7 @@ cdef class GoldParse: cdef readonly list heads cdef readonly list labels cdef readonly list ner + cdef readonly list ents cdef int* c_tags cdef int* c_heads diff --git a/spacy/syntax/conll.pyx b/spacy/syntax/conll.pyx index 1f4252138..3170ac09c 100644 --- a/spacy/syntax/conll.pyx +++ b/spacy/syntax/conll.pyx @@ -1,6 +1,5 @@ import numpy import codecs -from .ner_util import iob_to_biluo from libc.string cimport memset @@ -47,6 +46,7 @@ def _parse_line(line): label = pieces[7] return id_, word, pos, head_idx, label, iob_ent + cdef class GoldParse: def __init__(self, tokens, annot_tuples): self.mem = Pool() @@ -62,9 +62,12 @@ cdef class GoldParse: self.tags = [None] * len(tokens) self.heads = [-1] * len(tokens) self.labels = ['MISSING'] * len(tokens) - self.ner = [None] * len(tokens) + self.ner = ['O'] * len(tokens) idx_map = {token.idx: token.i for token in tokens} + self.ents = [] + ent_start = None + ent_label = None for idx, tag, head, label, ner in zip(*annot_tuples): if idx < tokens[0].idx: pass @@ -76,8 +79,29 @@ cdef class GoldParse: self.heads[i] = idx_map.get(head, -1) self.labels[i] = label self.tags[i] = tag - self.labels[i] = label - self.ner[i] = ner + if ner == '-': + self.ner[i] = '-' + # Deal with inconsistencies in BILUO arising from tokenization + if ner[0] in ('B', 'U', 'O') and ent_start is not None: + self.ents.append((ent_start, i, ent_label)) + ent_start = None + ent_label = None + if ner[0] in ('B', 'U'): + ent_start = i + ent_label = ner[2:] + if ent_start is not None: + self.ents.append((ent_start, self.length, ent_label)) + for start, end, label in self.ents: + if start == (end - 1): + self.ner[start] = 'U-%s' % label + else: + self.ner[start] = 'B-%s' % label + for i in range(start+1, end-1): + self.ner[i] = 'I-%s' % label + self.ner[end-1] = 'L-%s' % label + + def __len__(self): + return self.length @property def n_non_punct(self): diff --git a/spacy/syntax/ner.pyx b/spacy/syntax/ner.pyx index 3cfe7657b..b3af99da8 100644 --- a/spacy/syntax/ner.pyx +++ b/spacy/syntax/ner.pyx @@ -34,15 +34,14 @@ cdef do_func_t[N_MOVES] do_funcs cdef bint entity_is_open(const State *s) except -1: - return s.sent[s.i - 1].ent.tag >= 1 + return s.ents_len >= 1 and s.ent.end == 0 cdef bint _entity_is_sunk(const State *s, Transition* golds) except -1: if not entity_is_open(s): return False - cdef const Entity* curr = &s.sent[s.i - 1].ent - cdef const Transition* gold = &golds[(s.i - 1) + curr.start] + cdef const Transition* gold = &golds[(s.i - 1) + s.ent.start] if gold.move != BEGIN and gold.move != UNIT: return True elif gold.label != s.ent.label: @@ -52,14 +51,16 @@ cdef bint _entity_is_sunk(const State *s, Transition* golds) except -1: cdef int _is_valid(int act, int label, const State* s) except -1: - if act == BEGIN: - return not entity_is_open(s) + if act == MISSING: + return False + elif act == BEGIN: + return label != 0 and not entity_is_open(s) elif act == IN: - return entity_is_open(s) and s.ent.label == label + return entity_is_open(s) and label != 0 and s.ent.label == label elif act == LAST: - return entity_is_open(s) and s.ent.label == label + return entity_is_open(s) and label != 0 and s.ent.label == label elif act == UNIT: - return not entity_is_open(s) + return label != 0 and not entity_is_open(s) elif act == OUT: return not entity_is_open(s) else: @@ -69,22 +70,34 @@ cdef int _is_valid(int act, int label, const State* s) except -1: cdef class BiluoPushDown(TransitionSystem): @classmethod def get_labels(cls, gold_tuples): - move_labels = {BEGIN: {}, IN: {}, LAST: {}, UNIT: {}, OUT: {'ROOT': True}} - moves = ('-', 'B', 'I', 'L', 'U') - for (raw_text, toks, (ids, tags, heads, labels, iob)) in gold_tuples: - for i, ner_tag in enumerate(iob_to_biluo(iob)): + move_labels = {MISSING: {'ROOT': True}, BEGIN: {}, IN: {}, LAST: {}, UNIT: {}, + OUT: {'ROOT': True}} + moves = ('M', 'B', 'I', 'L', 'U') + for (raw_text, toks, (ids, tags, heads, labels, biluo)) in gold_tuples: + for i, ner_tag in enumerate(biluo): if ner_tag != 'O' and ner_tag != '-': move_str, label = ner_tag.split('-') move_labels[moves.index(move_str)][label] = True return move_labels + def move_name(self, int move, int label): + if move == OUT: + return 'O' + elif move == 'MISSING': + return 'M' + else: + labels = {id_: name for name, id_ in self.label_ids.items()} + return MOVE_NAMES[move] + '-' + labels[label] + cdef int preprocess_gold(self, GoldParse gold) except -1: - biluo_strings = iob_to_biluo(gold.ner) for i in range(gold.length): - gold.c_ner[i] = self.lookup_transition(biluo_strings[i]) + gold.c_ner[i] = self.lookup_transition(gold.ner[i]) cdef Transition lookup_transition(self, object name) except *: - if '-' in name: + if name == '-': + move_str = 'M' + label = 0 + elif '-' in name: move_str, label_str = name.split('-', 1) label = self.label_ids[label_str] else: @@ -107,6 +120,9 @@ cdef class BiluoPushDown(TransitionSystem): t.get_cost = _get_cost return t + cdef int first_state(self, State* state) except -1: + pass + cdef Transition best_valid(self, const weight_t* scores, const State* s) except *: cdef int best = -1 cdef weight_t score = -90000 @@ -128,8 +144,9 @@ cdef int _get_cost(const Transition* self, const State* s, GoldParse gold) excep return 9000 cdef bint is_sunk = _entity_is_sunk(s, gold.c_ner) cdef int next_act = gold.c_ner[s.i+1].move if s.i < s.sent_len else OUT - return not _is_gold(self.move, self.label, gold.c_ner[s.i].move, gold.c_ner[s.i].label, - next_act, is_sunk) + cdef bint is_gold = _is_gold(self.move, self.label, gold.c_ner[s.i].move, + gold.c_ner[s.i].label, next_act, is_sunk) + return not is_gold cdef bint _is_gold(int act, int tag, int g_act, int g_tag, int next_act, bint is_sunk): @@ -210,18 +227,21 @@ cdef int _do_begin(const Transition* self, State* s) except -1: s.ents_len += 1 s.ent.start = s.i s.ent.label = self.label - s.sent[s.i].ent.tag = self.clas + s.sent[s.i].ent_iob = 3 + s.sent[s.i].ent_type = self.label s.i += 1 cdef int _do_in(const Transition* self, State* s) except -1: - s.sent[s.i].ent.tag = self.clas + s.sent[s.i].ent_iob = 1 + s.sent[s.i].ent_type = self.label s.i += 1 cdef int _do_last(const Transition* self, State* s) except -1: s.ent.end = s.i+1 - s.sent[s.i].ent.tag = self.clas + s.sent[s.i].ent_iob = 1 + s.sent[s.i].ent_type = self.label s.i += 1 @@ -231,12 +251,13 @@ cdef int _do_unit(const Transition* self, State* s) except -1: s.ent.start = s.i s.ent.label = self.label s.ent.end = s.i+1 - s.sent[s.i].ent.tag = self.clas + s.sent[s.i].ent_iob = 3 + s.sent[s.i].ent_type = self.label s.i += 1 cdef int _do_out(const Transition* self, State* s) except -1: - s.sent[s.i].ent.tag = self.clas + s.sent[s.i].ent_iob = 2 s.i += 1 diff --git a/spacy/syntax/parser.pyx b/spacy/syntax/parser.pyx index 6360f4f8b..a18f4b5b7 100644 --- a/spacy/syntax/parser.pyx +++ b/spacy/syntax/parser.pyx @@ -28,14 +28,12 @@ from ..tokens cimport Tokens, TokenC from .arc_eager cimport TransitionSystem, Transition from .transition_system import OracleError -from ._state cimport init_state, State, is_final, get_idx, get_s0, get_s1, get_n0, get_n1 +from ._state cimport new_state, State, is_final, get_idx, get_s0, get_s1, get_n0, get_n1 from .conll cimport GoldParse from . import _parse_features from ._parse_features cimport fill_context, CONTEXT_SIZE -from ._ner_features cimport _ner_features - DEBUG = False def set_debug(val): @@ -50,7 +48,11 @@ cdef unicode print_state(State* s, list words): third = words[s.stack[-2]] + '_%d' % s.sent[s.stack[-2]].head n0 = words[s.i] n1 = words[s.i + 1] - return ' '.join((str(s.stack_len), third, second, top, '|', n0, n1)) + if s.ents_len: + ent = '%s %d-%d' % (s.ent.label, s.ent.start, s.ent.end) + else: + ent = '-' + return ' '.join((ent, str(s.stack_len), third, second, top, '|', n0, n1)) def get_templates(name): @@ -58,7 +60,7 @@ def get_templates(name): if name == 'zhang': return pf.arc_eager elif name == 'ner': - return _ner_features.basic + return pf.ner else: return (pf.unigrams + pf.s0_n0 + pf.s1_n0 + pf.s0_n1 + pf.n0_n1 + \ pf.tree_shape + pf.trigrams) @@ -79,7 +81,8 @@ cdef class GreedyParser: cdef atom_t[CONTEXT_SIZE] context cdef int n_feats cdef Pool mem = Pool() - cdef State* state = init_state(mem, tokens.data, tokens.length) + cdef State* state = new_state(mem, tokens.data, tokens.length) + self.moves.first_state(state) cdef Transition guess while not is_final(state): fill_context(context, state) @@ -99,10 +102,12 @@ cdef class GreedyParser: Transition best atom_t[CONTEXT_SIZE] context - + self.moves.preprocess_gold(gold) cdef Pool mem = Pool() - cdef State* state = init_state(mem, tokens.data, tokens.length) + cdef State* state = new_state(mem, tokens.data, tokens.length) + self.moves.first_state(state) + py_words = [t.orth_ for t in tokens] while not is_final(state): fill_context(context, state) scores = self.model.score(context) @@ -114,7 +119,3 @@ cdef class GreedyParser: best.do(&best, state) else: guess.do(&guess, state) - n_corr = gold.heads_correct(state.sent, score_punct=True) - if force_gold and n_corr != tokens.length: - raise OracleError - return n_corr diff --git a/spacy/syntax/transition_system.pxd b/spacy/syntax/transition_system.pxd index 2b02e1e03..35eba0bd7 100644 --- a/spacy/syntax/transition_system.pxd +++ b/spacy/syntax/transition_system.pxd @@ -29,6 +29,8 @@ cdef class TransitionSystem: cdef const Transition* c cdef readonly int n_moves + cdef int first_state(self, State* state) except -1 + cdef int preprocess_gold(self, GoldParse gold) except -1 cdef Transition lookup_transition(self, object name) except * diff --git a/spacy/syntax/transition_system.pyx b/spacy/syntax/transition_system.pyx index ae9591690..5c3a6fd72 100644 --- a/spacy/syntax/transition_system.pyx +++ b/spacy/syntax/transition_system.pyx @@ -28,6 +28,9 @@ cdef class TransitionSystem: self.label_ids['MISSING'] = -1 self.c = moves + cdef int first_state(self, State* state) except -1: + raise NotImplementedError + cdef int preprocess_gold(self, GoldParse gold) except -1: raise NotImplementedError diff --git a/spacy/tokens.pxd b/spacy/tokens.pxd index 7c102f973..89902d36d 100644 --- a/spacy/tokens.pxd +++ b/spacy/tokens.pxd @@ -39,6 +39,7 @@ cdef class Tokens: cdef unicode _string cdef tuple _tag_strings cdef tuple _dep_strings + cdef public tuple _ent_strings cdef public bint is_tagged cdef public bint is_parsed diff --git a/spacy/tokens.pyx b/spacy/tokens.pyx index 31dadc684..2ba855554 100644 --- a/spacy/tokens.pyx +++ b/spacy/tokens.pyx @@ -94,6 +94,7 @@ cdef class Tokens: self._py_tokens = [] self._tag_strings = tuple() # These will be set by the POS tagger and parser self._dep_strings = tuple() # The strings are arbitrary and model-specific. + self._ent_strings = tuple() # TODO: Clean this up def __getitem__(self, object i): """Retrieve a token. @@ -129,6 +130,28 @@ cdef class Tokens: cdef const TokenC* last = &self.data[self.length - 1] return self._string[:last.idx + last.lex.length] + property ents: + def __get__(self): + cdef int i + cdef const TokenC* token + cdef int start = -1 + cdef object label = None + for i in range(self.length): + token = &self.data[i] + if token.ent_iob == 1: + assert start != -1 + pass + elif token.ent_iob == 2: + if start != -1: + yield (start, i, label) + start = -1 + label = None + elif token.ent_iob == 3: + start = i + label = self._ent_strings[token.ent_type] + if start != -1: + yield (start, self.length, label) + cdef int push_back(self, int idx, LexemeOrToken lex_or_tok) except -1: if self.length == self.max_length: self._realloc(self.length * 2)