* Use StringStore to encode label names, instead of label_ids

Matthew Honnibal 2015-03-14 11:06:35 -04:00
parent 64db61bff1
commit 31fad99518
6 changed files with 28 additions and 31 deletions
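For orientation: the commit retires the parser's private label_ids dict and keys dependency labels through spaCy's shared StringStore instead, so a label such as u'nsubj' encodes to the same integer id wherever the string table is consulted. A minimal sketch of the interning behaviour the diff relies on (illustrative only; the real StringStore is a C-level hash table, and MiniStringStore is a made-up stand-in):

    class MiniStringStore(object):
        # Toy stand-in for spacy.strings.StringStore: interns strings to stable ints.
        def __init__(self):
            self._by_string = {}
            self._by_id = []

        def __getitem__(self, key):
            if isinstance(key, int):
                # int -> string: decode an id back to its label
                return self._by_id[key]
            # string -> int: assign a new id the first time a label is seen
            if key not in self._by_string:
                self._by_string[key] = len(self._by_id)
                self._by_id.append(key)
            return self._by_string[key]

    strings = MiniStringStore()
    nsubj = strings[u'nsubj']          # encode
    assert strings[nsubj] == u'nsubj'  # decode round-trips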

View File

@@ -56,7 +56,7 @@ cdef class ArcEager(TransitionSystem):
     cdef int preprocess_gold(self, GoldParse gold) except -1:
         for i in range(gold.length):
             gold.c_heads[i] = gold.heads[i]
-            gold.c_labels[i] = self.label_ids[gold.labels[i]]
+            gold.c_labels[i] = self.strings[gold.labels[i]]
 
     cdef Transition lookup_transition(self, object name) except *:
         if '-' in name:

View File

@@ -24,6 +24,7 @@ from thinc.features cimport count_feats
 from thinc.learner cimport LinearModel
 
 from ..tokens cimport Tokens, TokenC
+from ..strings cimport StringStore
 from .arc_eager cimport TransitionSystem, Transition
 from .transition_system import OracleError
@@ -67,10 +68,10 @@ def get_templates(name):
 
 cdef class GreedyParser:
-    def __init__(self, model_dir, transition_system):
+    def __init__(self, StringStore strings, model_dir, transition_system):
         assert os.path.exists(model_dir) and os.path.isdir(model_dir)
         self.cfg = Config.read(model_dir, 'config')
-        self.moves = transition_system(self.cfg.labels)
+        self.moves = transition_system(strings, self.cfg.labels)
         templates = get_templates(self.cfg.features)
         self.model = Model(self.moves.n_moves, templates, model_dir)
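Since GreedyParser now takes the string table as its first argument, the caller has to hand it the vocabulary's StringStore when constructing the parser. A hedged usage sketch; the exact call site lives in the pipeline code, not in this diff, and the names vocab, model_dir and ArcEager are assumed from the wider codebase:

    # Hypothetical wiring, not taken from this commit:
    parser = GreedyParser(vocab.strings, model_dir, ArcEager)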
@@ -89,7 +90,7 @@ cdef class GreedyParser:
             scores = self.model.score(context)
             guess = self.moves.best_valid(scores, state)
             guess.do(&guess, state)
-        tokens.set_parse(state.sent, self.moves.label_ids)
+        tokens.set_parse(state.sent)
         return 0
 
     def train(self, Tokens tokens, GoldParse gold, force_gold=False):

View File

@@ -4,6 +4,7 @@ from thinc.typedefs cimport weight_t
 from ..structs cimport TokenC
 from ._state cimport State
 from .conll cimport GoldParse
+from ..strings cimport StringStore
 
 
 cdef struct Transition:
@@ -24,8 +25,8 @@ ctypedef int (*do_func_t)(const Transition* self, State* state) except -1
 
 cdef class TransitionSystem:
-    cdef readonly dict label_ids
     cdef Pool mem
+    cdef StringStore strings
     cdef const Transition* c
     cdef readonly int n_moves

View File

@@ -12,20 +12,18 @@ class OracleError(Exception):
 
 cdef class TransitionSystem:
-    def __init__(self, dict labels_by_action):
+    def __init__(self, StringStore string_table, dict labels_by_action):
         self.mem = Pool()
         self.n_moves = sum(len(labels) for labels in labels_by_action.values())
         moves = <Transition*>self.mem.alloc(self.n_moves, sizeof(Transition))
         cdef int i = 0
         cdef int label_id
-        self.label_ids = {'ROOT': 0}
+        self.strings = string_table
         for action, label_strs in sorted(labels_by_action.items()):
             for label_str in sorted(label_strs):
-                label_str = unicode(label_str)
-                label_id = self.label_ids.setdefault(label_str, len(self.label_ids))
+                label_id = self.strings[unicode(label_str)]
                 moves[i] = self.init_transition(i, int(action), label_id)
                 i += 1
-        self.label_ids['MISSING'] = -1
         self.c = moves
 
     cdef int first_state(self, State* state) except -1:
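The net effect on label encoding, reduced to plain Python for comparison; here label_str stands in for the loop variable above and strings for the shared StringStore passed in as string_table:

    # Before: a per-parser dict, with 'ROOT' pinned to 0, new labels added via
    # setdefault, and a 'MISSING' sentinel patched in at -1 afterwards.
    label_ids = {'ROOT': 0}
    label_id = label_ids.setdefault(label_str, len(label_ids))

    # After: ask the shared string table, which interns the label on first use,
    # so the transition system, the gold preprocessing and Token.dep_ all agree
    # on the same id for the same label string.
    label_id = strings[unicode(label_str)]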

View File

@@ -38,8 +38,6 @@ cdef class Tokens:
     cdef list _py_tokens
     cdef unicode _string
    cdef tuple _tag_strings
-    cdef tuple _dep_strings
-    cdef public tuple _ent_strings
 
     cdef public bint is_tagged
     cdef public bint is_parsed
@@ -71,13 +69,11 @@ cdef class Token:
     cdef Tokens _seq
-    cdef tuple _tag_strings
-    cdef tuple _dep_strings
 
     @staticmethod
     cdef inline Token cinit(Vocab vocab, unicode string,
                             const TokenC* token, int offset, int array_len,
-                            Tokens parent_seq, tuple tag_strings, tuple dep_strings):
+                            Tokens parent_seq, tuple tag_strings):
         if offset < 0 or offset >= array_len:
             msg = "Attempt to access token at %d, max length %d"
@@ -92,8 +88,6 @@ cdef class Token:
         self.array_len = array_len
         self._seq = parent_seq
-        self._tag_strings = tag_strings
-        self._dep_strings = dep_strings
         self._seq._py_tokens[offset] = self
         return self

View File

@@ -93,8 +93,6 @@ cdef class Tokens:
         self.is_parsed = False
         self._py_tokens = []
         self._tag_strings = tuple()  # These will be set by the POS tagger and parser
-        self._dep_strings = tuple()  # The strings are arbitrary and model-specific.
-        self._ent_strings = tuple()  # TODO: Clean this up
 
     def __getitem__(self, object i):
         """Retrieve a token.
@@ -110,7 +108,7 @@ cdef class Tokens:
         bounds_check(i, self.length, PADDING)
         return Token.cinit(self.vocab, self._string,
                            &self.data[i], i, self.length,
-                           self, self._tag_strings, self._dep_strings)
+                           self, self._tag_strings)
 
     def __iter__(self):
         """Iterate over the tokens.
@@ -121,7 +119,7 @@ cdef class Tokens:
         for i in range(self.length):
             yield Token.cinit(self.vocab, self._string,
                               &self.data[i], i, self.length,
-                              self, self._tag_strings, self._dep_strings)
+                              self, self._tag_strings)
 
     def __len__(self):
         return self.length
@@ -148,7 +146,7 @@ cdef class Tokens:
                 label = None
             elif token.ent_iob == 3:
                 start = i
-                label = self._ent_strings[token.ent_type]
+                label = self.vocab.strings[token.ent_type]
         if start != -1:
             yield (start, self.length, label)
@@ -252,10 +250,6 @@ cdef class Tokens:
         self.is_parsed = True
         for i in range(self.length):
             self.data[i] = parsed[i]
-        dep_strings = [None] * len(label_ids)
-        for dep_string, id_ in label_ids.items():
-            dep_strings[id_] = dep_string
-        self._dep_strings = tuple(dep_strings)
 
 
 cdef class Span:
@@ -265,6 +259,15 @@ cdef class Span:
         self.start = start
         self.end = end
 
+    def __richcmp__(self, Span other, int op):
+        # Eq
+        if op in (1, 2, 5):
+            if self._seq is other._seq and \
+               self.start == other.start and \
+               self.end == other.end:
+                return True
+        return False
+
     def __len__(self):
         if self.end < self.start:
             return 0
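For readers unfamiliar with Cython's __richcmp__, the int op argument is CPython's rich-comparison code rather than an operator, which is what the (1, 2, 5) membership test above keys on:

    # CPython rich-comparison op codes received by __richcmp__:
    #   0: <    1: <=    2: ==    3: !=    4: >    5: >=
    # Two spans over the same sequence with equal start and end satisfy <=, ==
    # and >=, so the method answers True exactly for op in (1, 2, 5).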
@@ -310,7 +313,7 @@ cdef class Token:
     def nbor(self, int i=1):
         return Token.cinit(self.vocab, self._string,
                            self.c, self.i, self.array_len,
-                           self._seq, self._tag_strings, self._dep_strings)
+                           self._seq, self._tag_strings)
 
     property string:
         def __get__(self):
@@ -411,7 +414,7 @@ cdef class Token:
                 elif ptr + ptr.head == self.c:
                     yield Token.cinit(self.vocab, self._string,
                                       ptr, ptr - (self.c - self.i), self.array_len,
-                                      self._seq, self._tag_strings, self._dep_strings)
+                                      self._seq, self._tag_strings)
                     ptr += 1
                 else:
                     ptr += 1
@@ -430,7 +433,7 @@ cdef class Token:
                 elif ptr + ptr.head == self.c:
                     yield Token.cinit(self.vocab, self._string,
                                       ptr, ptr - (self.c - self.i), self.array_len,
-                                      self._seq, self._tag_strings, self._dep_strings)
+                                      self._seq, self._tag_strings)
                     ptr -= 1
                 else:
                     ptr -= 1
@@ -453,7 +456,7 @@ cdef class Token:
             """The token predicted by the parser to be the head of the current token."""
             return Token.cinit(self.vocab, self._string,
                                self.c + self.c.head, self.i + self.c.head, self.array_len,
-                               self._seq, self._tag_strings, self._dep_strings)
+                               self._seq, self._tag_strings)
 
     property whitespace_:
         def __get__(self):
@@ -497,7 +500,7 @@ cdef class Token:
     property dep_:
         def __get__(self):
-            return self._dep_strings[self.c.dep]
+            return self.vocab.strings[self.c.dep]
 
 
 _pos_id_to_string = {id_: string for string, id_ in UNIV_POS_NAMES.items()}
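With the label stored as a StringStore id on the C token struct, decoding dep_ is just a reverse lookup through the vocabulary's table. A hedged round-trip sketch; vocab and token are assumed objects from the surrounding pipeline, not defined in this diff:

    # Hypothetical round trip through the shared table:
    dep_id = vocab.strings[u'nsubj']           # encode, as the transition system does
    assert vocab.strings[dep_id] == u'nsubj'   # decode back to the label string
    print(token.dep_)                          # now reads self.vocab.strings[self.c.dep]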