diff --git a/spacy/syntax/arc_eager.pyx b/spacy/syntax/arc_eager.pyx index b705086e8..e5a7406f6 100644 --- a/spacy/syntax/arc_eager.pyx +++ b/spacy/syntax/arc_eager.pyx @@ -56,7 +56,7 @@ cdef class ArcEager(TransitionSystem): cdef int preprocess_gold(self, GoldParse gold) except -1: for i in range(gold.length): gold.c_heads[i] = gold.heads[i] - gold.c_labels[i] = self.label_ids[gold.labels[i]] + gold.c_labels[i] = self.strings[gold.labels[i]] cdef Transition lookup_transition(self, object name) except *: if '-' in name: diff --git a/spacy/syntax/parser.pyx b/spacy/syntax/parser.pyx index 45bb61103..4a925dd8c 100644 --- a/spacy/syntax/parser.pyx +++ b/spacy/syntax/parser.pyx @@ -24,6 +24,7 @@ from thinc.features cimport count_feats from thinc.learner cimport LinearModel from ..tokens cimport Tokens, TokenC +from ..strings cimport StringStore from .arc_eager cimport TransitionSystem, Transition from .transition_system import OracleError @@ -67,10 +68,10 @@ def get_templates(name): cdef class GreedyParser: - def __init__(self, model_dir, transition_system): + def __init__(self, StringStore strings, model_dir, transition_system): assert os.path.exists(model_dir) and os.path.isdir(model_dir) self.cfg = Config.read(model_dir, 'config') - self.moves = transition_system(self.cfg.labels) + self.moves = transition_system(strings, self.cfg.labels) templates = get_templates(self.cfg.features) self.model = Model(self.moves.n_moves, templates, model_dir) @@ -89,7 +90,7 @@ cdef class GreedyParser: scores = self.model.score(context) guess = self.moves.best_valid(scores, state) guess.do(&guess, state) - tokens.set_parse(state.sent, self.moves.label_ids) + tokens.set_parse(state.sent) return 0 def train(self, Tokens tokens, GoldParse gold, force_gold=False): diff --git a/spacy/syntax/transition_system.pxd b/spacy/syntax/transition_system.pxd index 35eba0bd7..58aa90d99 100644 --- a/spacy/syntax/transition_system.pxd +++ b/spacy/syntax/transition_system.pxd @@ -4,6 +4,7 @@ from thinc.typedefs cimport weight_t from ..structs cimport TokenC from ._state cimport State from .conll cimport GoldParse +from ..strings cimport StringStore cdef struct Transition: @@ -24,8 +25,8 @@ ctypedef int (*do_func_t)(const Transition* self, State* state) except -1 cdef class TransitionSystem: - cdef readonly dict label_ids cdef Pool mem + cdef StringStore strings cdef const Transition* c cdef readonly int n_moves diff --git a/spacy/syntax/transition_system.pyx b/spacy/syntax/transition_system.pyx index 5c3a6fd72..72e9cedf8 100644 --- a/spacy/syntax/transition_system.pyx +++ b/spacy/syntax/transition_system.pyx @@ -12,20 +12,18 @@ class OracleError(Exception): cdef class TransitionSystem: - def __init__(self, dict labels_by_action): + def __init__(self, StringStore string_table, dict labels_by_action): self.mem = Pool() self.n_moves = sum(len(labels) for labels in labels_by_action.values()) moves = self.mem.alloc(self.n_moves, sizeof(Transition)) cdef int i = 0 cdef int label_id - self.label_ids = {'ROOT': 0} + self.strings = string_table for action, label_strs in sorted(labels_by_action.items()): for label_str in sorted(label_strs): - label_str = unicode(label_str) - label_id = self.label_ids.setdefault(label_str, len(self.label_ids)) + label_id = self.strings[unicode(label_str)] moves[i] = self.init_transition(i, int(action), label_id) i += 1 - self.label_ids['MISSING'] = -1 self.c = moves cdef int first_state(self, State* state) except -1: diff --git a/spacy/tokens.pxd b/spacy/tokens.pxd index efc03c368..169ed0b1b 100644 --- a/spacy/tokens.pxd +++ b/spacy/tokens.pxd @@ -38,8 +38,6 @@ cdef class Tokens: cdef list _py_tokens cdef unicode _string cdef tuple _tag_strings - cdef tuple _dep_strings - cdef public tuple _ent_strings cdef public bint is_tagged cdef public bint is_parsed @@ -71,13 +69,11 @@ cdef class Token: cdef Tokens _seq - cdef tuple _tag_strings - cdef tuple _dep_strings @staticmethod cdef inline Token cinit(Vocab vocab, unicode string, const TokenC* token, int offset, int array_len, - Tokens parent_seq, tuple tag_strings, tuple dep_strings): + Tokens parent_seq, self._tag_strings): if offset < 0 or offset >= array_len: msg = "Attempt to access token at %d, max length %d" @@ -92,8 +88,6 @@ cdef class Token: self.array_len = array_len self._seq = parent_seq - self._tag_strings = tag_strings - self._dep_strings = dep_strings self._seq._py_tokens[offset] = self return self diff --git a/spacy/tokens.pyx b/spacy/tokens.pyx index 677f7a101..696e14ea3 100644 --- a/spacy/tokens.pyx +++ b/spacy/tokens.pyx @@ -93,8 +93,6 @@ cdef class Tokens: self.is_parsed = False self._py_tokens = [] self._tag_strings = tuple() # These will be set by the POS tagger and parser - self._dep_strings = tuple() # The strings are arbitrary and model-specific. - self._ent_strings = tuple() # TODO: Clean this up def __getitem__(self, object i): """Retrieve a token. @@ -110,7 +108,7 @@ cdef class Tokens: bounds_check(i, self.length, PADDING) return Token.cinit(self.vocab, self._string, &self.data[i], i, self.length, - self, self._tag_strings, self._dep_strings) + self, self._tag_strings) def __iter__(self): """Iterate over the tokens. @@ -121,7 +119,7 @@ cdef class Tokens: for i in range(self.length): yield Token.cinit(self.vocab, self._string, &self.data[i], i, self.length, - self, self._tag_strings, self._dep_strings) + self, self._tag_strings) def __len__(self): return self.length @@ -148,7 +146,7 @@ cdef class Tokens: label = None elif token.ent_iob == 3: start = i - label = self._ent_strings[token.ent_type] + label = self.vocab.strings[token.ent_type] if start != -1: yield (start, self.length, label) @@ -252,10 +250,6 @@ cdef class Tokens: self.is_parsed = True for i in range(self.length): self.data[i] = parsed[i] - dep_strings = [None] * len(label_ids) - for dep_string, id_ in label_ids.items(): - dep_strings[id_] = dep_string - self._dep_strings = tuple(dep_strings) cdef class Span: @@ -265,6 +259,15 @@ cdef class Span: self.start = start self.end = end + def __richcmp__(self, Span other, int op): + # Eq + if op in (1, 2, 5): + if self._seq is other._seq and \ + self.start == other.start and \ + self.end == other.end: + return True + return False + def __len__(self): if self.end < self.start: return 0 @@ -310,7 +313,7 @@ cdef class Token: def nbor(self, int i=1): return Token.cinit(self.vocab, self._string, self.c, self.i, self.array_len, - self._seq, self._tag_strings, self._dep_strings) + self._seq, self._tag_strings) property string: def __get__(self): @@ -411,7 +414,7 @@ cdef class Token: elif ptr + ptr.head == self.c: yield Token.cinit(self.vocab, self._string, ptr, ptr - (self.c - self.i), self.array_len, - self._seq, self._tag_strings, self._dep_strings) + self._seq, self._tag_strings) ptr += 1 else: ptr += 1 @@ -430,7 +433,7 @@ cdef class Token: elif ptr + ptr.head == self.c: yield Token.cinit(self.vocab, self._string, ptr, ptr - (self.c - self.i), self.array_len, - self._seq, self._tag_strings, self._dep_strings) + self._seq, self._tag_strings) ptr -= 1 else: ptr -= 1 @@ -453,7 +456,7 @@ cdef class Token: """The token predicted by the parser to be the head of the current token.""" return Token.cinit(self.vocab, self._string, self.c + self.c.head, self.i + self.c.head, self.array_len, - self._seq, self._tag_strings, self._dep_strings) + self._seq, self._tag_strings) property whitespace_: def __get__(self): @@ -497,7 +500,7 @@ cdef class Token: property dep_: def __get__(self): - return self._dep_strings[self.c.dep] + return self.vocab.strings[self.c.dep] _pos_id_to_string = {id_: string for string, id_ in UNIV_POS_NAMES.items()}