* Use StringStore to encode label names, instead of label_ids

This commit is contained in:
Matthew Honnibal 2015-03-14 11:06:35 -04:00
parent 64db61bff1
commit 31fad99518
6 changed files with 28 additions and 31 deletions

View File

@ -56,7 +56,7 @@ cdef class ArcEager(TransitionSystem):
cdef int preprocess_gold(self, GoldParse gold) except -1:
for i in range(gold.length):
gold.c_heads[i] = gold.heads[i]
gold.c_labels[i] = self.label_ids[gold.labels[i]]
gold.c_labels[i] = self.strings[gold.labels[i]]
cdef Transition lookup_transition(self, object name) except *:
if '-' in name:

View File

@ -24,6 +24,7 @@ from thinc.features cimport count_feats
from thinc.learner cimport LinearModel
from ..tokens cimport Tokens, TokenC
from ..strings cimport StringStore
from .arc_eager cimport TransitionSystem, Transition
from .transition_system import OracleError
@ -67,10 +68,10 @@ def get_templates(name):
cdef class GreedyParser:
def __init__(self, model_dir, transition_system):
def __init__(self, StringStore strings, model_dir, transition_system):
assert os.path.exists(model_dir) and os.path.isdir(model_dir)
self.cfg = Config.read(model_dir, 'config')
self.moves = transition_system(self.cfg.labels)
self.moves = transition_system(strings, self.cfg.labels)
templates = get_templates(self.cfg.features)
self.model = Model(self.moves.n_moves, templates, model_dir)
@ -89,7 +90,7 @@ cdef class GreedyParser:
scores = self.model.score(context)
guess = self.moves.best_valid(scores, state)
guess.do(&guess, state)
tokens.set_parse(state.sent, self.moves.label_ids)
tokens.set_parse(state.sent)
return 0
def train(self, Tokens tokens, GoldParse gold, force_gold=False):

View File

@ -4,6 +4,7 @@ from thinc.typedefs cimport weight_t
from ..structs cimport TokenC
from ._state cimport State
from .conll cimport GoldParse
from ..strings cimport StringStore
cdef struct Transition:
@ -24,8 +25,8 @@ ctypedef int (*do_func_t)(const Transition* self, State* state) except -1
cdef class TransitionSystem:
cdef readonly dict label_ids
cdef Pool mem
cdef StringStore strings
cdef const Transition* c
cdef readonly int n_moves

View File

@ -12,20 +12,18 @@ class OracleError(Exception):
cdef class TransitionSystem:
def __init__(self, dict labels_by_action):
def __init__(self, StringStore string_table, dict labels_by_action):
self.mem = Pool()
self.n_moves = sum(len(labels) for labels in labels_by_action.values())
moves = <Transition*>self.mem.alloc(self.n_moves, sizeof(Transition))
cdef int i = 0
cdef int label_id
self.label_ids = {'ROOT': 0}
self.strings = string_table
for action, label_strs in sorted(labels_by_action.items()):
for label_str in sorted(label_strs):
label_str = unicode(label_str)
label_id = self.label_ids.setdefault(label_str, len(self.label_ids))
label_id = self.strings[unicode(label_str)]
moves[i] = self.init_transition(i, int(action), label_id)
i += 1
self.label_ids['MISSING'] = -1
self.c = moves
cdef int first_state(self, State* state) except -1:

View File

@ -38,8 +38,6 @@ cdef class Tokens:
cdef list _py_tokens
cdef unicode _string
cdef tuple _tag_strings
cdef tuple _dep_strings
cdef public tuple _ent_strings
cdef public bint is_tagged
cdef public bint is_parsed
@ -71,13 +69,11 @@ cdef class Token:
cdef Tokens _seq
cdef tuple _tag_strings
cdef tuple _dep_strings
@staticmethod
cdef inline Token cinit(Vocab vocab, unicode string,
const TokenC* token, int offset, int array_len,
Tokens parent_seq, tuple tag_strings, tuple dep_strings):
Tokens parent_seq, self._tag_strings):
if offset < 0 or offset >= array_len:
msg = "Attempt to access token at %d, max length %d"
@ -92,8 +88,6 @@ cdef class Token:
self.array_len = array_len
self._seq = parent_seq
self._tag_strings = tag_strings
self._dep_strings = dep_strings
self._seq._py_tokens[offset] = self
return self

View File

@ -93,8 +93,6 @@ cdef class Tokens:
self.is_parsed = False
self._py_tokens = []
self._tag_strings = tuple() # These will be set by the POS tagger and parser
self._dep_strings = tuple() # The strings are arbitrary and model-specific.
self._ent_strings = tuple() # TODO: Clean this up
def __getitem__(self, object i):
"""Retrieve a token.
@ -110,7 +108,7 @@ cdef class Tokens:
bounds_check(i, self.length, PADDING)
return Token.cinit(self.vocab, self._string,
&self.data[i], i, self.length,
self, self._tag_strings, self._dep_strings)
self, self._tag_strings)
def __iter__(self):
"""Iterate over the tokens.
@ -121,7 +119,7 @@ cdef class Tokens:
for i in range(self.length):
yield Token.cinit(self.vocab, self._string,
&self.data[i], i, self.length,
self, self._tag_strings, self._dep_strings)
self, self._tag_strings)
def __len__(self):
return self.length
@ -148,7 +146,7 @@ cdef class Tokens:
label = None
elif token.ent_iob == 3:
start = i
label = self._ent_strings[token.ent_type]
label = self.vocab.strings[token.ent_type]
if start != -1:
yield (start, self.length, label)
@ -252,10 +250,6 @@ cdef class Tokens:
self.is_parsed = True
for i in range(self.length):
self.data[i] = parsed[i]
dep_strings = [None] * len(label_ids)
for dep_string, id_ in label_ids.items():
dep_strings[id_] = dep_string
self._dep_strings = tuple(dep_strings)
cdef class Span:
@ -265,6 +259,15 @@ cdef class Span:
self.start = start
self.end = end
def __richcmp__(self, Span other, int op):
# Eq
if op in (1, 2, 5):
if self._seq is other._seq and \
self.start == other.start and \
self.end == other.end:
return True
return False
def __len__(self):
if self.end < self.start:
return 0
@ -310,7 +313,7 @@ cdef class Token:
def nbor(self, int i=1):
return Token.cinit(self.vocab, self._string,
self.c, self.i, self.array_len,
self._seq, self._tag_strings, self._dep_strings)
self._seq, self._tag_strings)
property string:
def __get__(self):
@ -411,7 +414,7 @@ cdef class Token:
elif ptr + ptr.head == self.c:
yield Token.cinit(self.vocab, self._string,
ptr, ptr - (self.c - self.i), self.array_len,
self._seq, self._tag_strings, self._dep_strings)
self._seq, self._tag_strings)
ptr += 1
else:
ptr += 1
@ -430,7 +433,7 @@ cdef class Token:
elif ptr + ptr.head == self.c:
yield Token.cinit(self.vocab, self._string,
ptr, ptr - (self.c - self.i), self.array_len,
self._seq, self._tag_strings, self._dep_strings)
self._seq, self._tag_strings)
ptr -= 1
else:
ptr -= 1
@ -453,7 +456,7 @@ cdef class Token:
"""The token predicted by the parser to be the head of the current token."""
return Token.cinit(self.vocab, self._string,
self.c + self.c.head, self.i + self.c.head, self.array_len,
self._seq, self._tag_strings, self._dep_strings)
self._seq, self._tag_strings)
property whitespace_:
def __get__(self):
@ -497,7 +500,7 @@ cdef class Token:
property dep_:
def __get__(self):
return self._dep_strings[self.c.dep]
return self.vocab.strings[self.c.dep]
_pos_id_to_string = {id_: string for string, id_ in UNIV_POS_NAMES.items()}