mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 17:36:30 +03:00
* Use StringStore to encode label names, instead of label_ids
This commit is contained in:
parent
64db61bff1
commit
31fad99518
|
@ -56,7 +56,7 @@ cdef class ArcEager(TransitionSystem):
|
|||
cdef int preprocess_gold(self, GoldParse gold) except -1:
|
||||
for i in range(gold.length):
|
||||
gold.c_heads[i] = gold.heads[i]
|
||||
gold.c_labels[i] = self.label_ids[gold.labels[i]]
|
||||
gold.c_labels[i] = self.strings[gold.labels[i]]
|
||||
|
||||
cdef Transition lookup_transition(self, object name) except *:
|
||||
if '-' in name:
|
||||
|
|
|
@ -24,6 +24,7 @@ from thinc.features cimport count_feats
|
|||
from thinc.learner cimport LinearModel
|
||||
|
||||
from ..tokens cimport Tokens, TokenC
|
||||
from ..strings cimport StringStore
|
||||
|
||||
from .arc_eager cimport TransitionSystem, Transition
|
||||
from .transition_system import OracleError
|
||||
|
@ -67,10 +68,10 @@ def get_templates(name):
|
|||
|
||||
|
||||
cdef class GreedyParser:
|
||||
def __init__(self, model_dir, transition_system):
|
||||
def __init__(self, StringStore strings, model_dir, transition_system):
|
||||
assert os.path.exists(model_dir) and os.path.isdir(model_dir)
|
||||
self.cfg = Config.read(model_dir, 'config')
|
||||
self.moves = transition_system(self.cfg.labels)
|
||||
self.moves = transition_system(strings, self.cfg.labels)
|
||||
templates = get_templates(self.cfg.features)
|
||||
self.model = Model(self.moves.n_moves, templates, model_dir)
|
||||
|
||||
|
@ -89,7 +90,7 @@ cdef class GreedyParser:
|
|||
scores = self.model.score(context)
|
||||
guess = self.moves.best_valid(scores, state)
|
||||
guess.do(&guess, state)
|
||||
tokens.set_parse(state.sent, self.moves.label_ids)
|
||||
tokens.set_parse(state.sent)
|
||||
return 0
|
||||
|
||||
def train(self, Tokens tokens, GoldParse gold, force_gold=False):
|
||||
|
|
|
@ -4,6 +4,7 @@ from thinc.typedefs cimport weight_t
|
|||
from ..structs cimport TokenC
|
||||
from ._state cimport State
|
||||
from .conll cimport GoldParse
|
||||
from ..strings cimport StringStore
|
||||
|
||||
|
||||
cdef struct Transition:
|
||||
|
@ -24,8 +25,8 @@ ctypedef int (*do_func_t)(const Transition* self, State* state) except -1
|
|||
|
||||
|
||||
cdef class TransitionSystem:
|
||||
cdef readonly dict label_ids
|
||||
cdef Pool mem
|
||||
cdef StringStore strings
|
||||
cdef const Transition* c
|
||||
cdef readonly int n_moves
|
||||
|
||||
|
|
|
@ -12,20 +12,18 @@ class OracleError(Exception):
|
|||
|
||||
|
||||
cdef class TransitionSystem:
|
||||
def __init__(self, dict labels_by_action):
|
||||
def __init__(self, StringStore string_table, dict labels_by_action):
|
||||
self.mem = Pool()
|
||||
self.n_moves = sum(len(labels) for labels in labels_by_action.values())
|
||||
moves = <Transition*>self.mem.alloc(self.n_moves, sizeof(Transition))
|
||||
cdef int i = 0
|
||||
cdef int label_id
|
||||
self.label_ids = {'ROOT': 0}
|
||||
self.strings = string_table
|
||||
for action, label_strs in sorted(labels_by_action.items()):
|
||||
for label_str in sorted(label_strs):
|
||||
label_str = unicode(label_str)
|
||||
label_id = self.label_ids.setdefault(label_str, len(self.label_ids))
|
||||
label_id = self.strings[unicode(label_str)]
|
||||
moves[i] = self.init_transition(i, int(action), label_id)
|
||||
i += 1
|
||||
self.label_ids['MISSING'] = -1
|
||||
self.c = moves
|
||||
|
||||
cdef int first_state(self, State* state) except -1:
|
||||
|
|
|
@ -38,8 +38,6 @@ cdef class Tokens:
|
|||
cdef list _py_tokens
|
||||
cdef unicode _string
|
||||
cdef tuple _tag_strings
|
||||
cdef tuple _dep_strings
|
||||
cdef public tuple _ent_strings
|
||||
|
||||
cdef public bint is_tagged
|
||||
cdef public bint is_parsed
|
||||
|
@ -71,13 +69,11 @@ cdef class Token:
|
|||
|
||||
|
||||
cdef Tokens _seq
|
||||
cdef tuple _tag_strings
|
||||
cdef tuple _dep_strings
|
||||
|
||||
@staticmethod
|
||||
cdef inline Token cinit(Vocab vocab, unicode string,
|
||||
const TokenC* token, int offset, int array_len,
|
||||
Tokens parent_seq, tuple tag_strings, tuple dep_strings):
|
||||
Tokens parent_seq, self._tag_strings):
|
||||
if offset < 0 or offset >= array_len:
|
||||
|
||||
msg = "Attempt to access token at %d, max length %d"
|
||||
|
@ -92,8 +88,6 @@ cdef class Token:
|
|||
self.array_len = array_len
|
||||
|
||||
self._seq = parent_seq
|
||||
self._tag_strings = tag_strings
|
||||
self._dep_strings = dep_strings
|
||||
self._seq._py_tokens[offset] = self
|
||||
return self
|
||||
|
||||
|
|
|
@ -93,8 +93,6 @@ cdef class Tokens:
|
|||
self.is_parsed = False
|
||||
self._py_tokens = []
|
||||
self._tag_strings = tuple() # These will be set by the POS tagger and parser
|
||||
self._dep_strings = tuple() # The strings are arbitrary and model-specific.
|
||||
self._ent_strings = tuple() # TODO: Clean this up
|
||||
|
||||
def __getitem__(self, object i):
|
||||
"""Retrieve a token.
|
||||
|
@ -110,7 +108,7 @@ cdef class Tokens:
|
|||
bounds_check(i, self.length, PADDING)
|
||||
return Token.cinit(self.vocab, self._string,
|
||||
&self.data[i], i, self.length,
|
||||
self, self._tag_strings, self._dep_strings)
|
||||
self, self._tag_strings)
|
||||
|
||||
def __iter__(self):
|
||||
"""Iterate over the tokens.
|
||||
|
@ -121,7 +119,7 @@ cdef class Tokens:
|
|||
for i in range(self.length):
|
||||
yield Token.cinit(self.vocab, self._string,
|
||||
&self.data[i], i, self.length,
|
||||
self, self._tag_strings, self._dep_strings)
|
||||
self, self._tag_strings)
|
||||
|
||||
def __len__(self):
|
||||
return self.length
|
||||
|
@ -148,7 +146,7 @@ cdef class Tokens:
|
|||
label = None
|
||||
elif token.ent_iob == 3:
|
||||
start = i
|
||||
label = self._ent_strings[token.ent_type]
|
||||
label = self.vocab.strings[token.ent_type]
|
||||
if start != -1:
|
||||
yield (start, self.length, label)
|
||||
|
||||
|
@ -252,10 +250,6 @@ cdef class Tokens:
|
|||
self.is_parsed = True
|
||||
for i in range(self.length):
|
||||
self.data[i] = parsed[i]
|
||||
dep_strings = [None] * len(label_ids)
|
||||
for dep_string, id_ in label_ids.items():
|
||||
dep_strings[id_] = dep_string
|
||||
self._dep_strings = tuple(dep_strings)
|
||||
|
||||
|
||||
cdef class Span:
|
||||
|
@ -265,6 +259,15 @@ cdef class Span:
|
|||
self.start = start
|
||||
self.end = end
|
||||
|
||||
def __richcmp__(self, Span other, int op):
|
||||
# Eq
|
||||
if op in (1, 2, 5):
|
||||
if self._seq is other._seq and \
|
||||
self.start == other.start and \
|
||||
self.end == other.end:
|
||||
return True
|
||||
return False
|
||||
|
||||
def __len__(self):
|
||||
if self.end < self.start:
|
||||
return 0
|
||||
|
@ -310,7 +313,7 @@ cdef class Token:
|
|||
def nbor(self, int i=1):
|
||||
return Token.cinit(self.vocab, self._string,
|
||||
self.c, self.i, self.array_len,
|
||||
self._seq, self._tag_strings, self._dep_strings)
|
||||
self._seq, self._tag_strings)
|
||||
|
||||
property string:
|
||||
def __get__(self):
|
||||
|
@ -411,7 +414,7 @@ cdef class Token:
|
|||
elif ptr + ptr.head == self.c:
|
||||
yield Token.cinit(self.vocab, self._string,
|
||||
ptr, ptr - (self.c - self.i), self.array_len,
|
||||
self._seq, self._tag_strings, self._dep_strings)
|
||||
self._seq, self._tag_strings)
|
||||
ptr += 1
|
||||
else:
|
||||
ptr += 1
|
||||
|
@ -430,7 +433,7 @@ cdef class Token:
|
|||
elif ptr + ptr.head == self.c:
|
||||
yield Token.cinit(self.vocab, self._string,
|
||||
ptr, ptr - (self.c - self.i), self.array_len,
|
||||
self._seq, self._tag_strings, self._dep_strings)
|
||||
self._seq, self._tag_strings)
|
||||
ptr -= 1
|
||||
else:
|
||||
ptr -= 1
|
||||
|
@ -453,7 +456,7 @@ cdef class Token:
|
|||
"""The token predicted by the parser to be the head of the current token."""
|
||||
return Token.cinit(self.vocab, self._string,
|
||||
self.c + self.c.head, self.i + self.c.head, self.array_len,
|
||||
self._seq, self._tag_strings, self._dep_strings)
|
||||
self._seq, self._tag_strings)
|
||||
|
||||
property whitespace_:
|
||||
def __get__(self):
|
||||
|
@ -497,7 +500,7 @@ cdef class Token:
|
|||
|
||||
property dep_:
|
||||
def __get__(self):
|
||||
return self._dep_strings[self.c.dep]
|
||||
return self.vocab.strings[self.c.dep]
|
||||
|
||||
|
||||
_pos_id_to_string = {id_: string for string, id_ in UNIV_POS_NAMES.items()}
|
||||
|
|
Loading…
Reference in New Issue
Block a user