mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-13 10:46:29 +03:00
* Work on refactored parser, where TransitionSystem can be easily subclassed
This commit is contained in:
parent
1cc6329b18
commit
dc986dbc0b
|
@ -53,9 +53,9 @@ cdef unicode print_state(State* s, list words):
|
|||
def get_templates(name):
|
||||
pf = _parse_features
|
||||
if name == 'zhang':
|
||||
return pf.unigrams, pf.arc_eager
|
||||
return pf.arc_eager
|
||||
else:
|
||||
return pf.unigrams, (pf.unigrams + pf.s0_n0 + pf.s1_n0 + pf.s0_n1 + pf.n0_n1 + \
|
||||
return (pf.unigrams + pf.s0_n0 + pf.s1_n0 + pf.s0_n1 + pf.n0_n1 + \
|
||||
pf.tree_shape + pf.trigrams)
|
||||
|
||||
|
||||
|
@ -64,84 +64,51 @@ cdef class GreedyParser:
|
|||
assert os.path.exists(model_dir) and os.path.isdir(model_dir)
|
||||
self.cfg = Config.read(model_dir, 'config')
|
||||
self.moves = TransitionSystem(self.cfg.left_labels, self.cfg.right_labels)
|
||||
hasty_templ, full_templ = get_templates(self.cfg.features)
|
||||
self.model = Model(self.moves.n_moves, full_templ, model_dir)
|
||||
templates = get_templates(self.cfg.features)
|
||||
self.model = Model(self.moves.n_moves, templates, model_dir)
|
||||
|
||||
def __call__(self, Tokens tokens):
|
||||
cdef:
|
||||
Transition guess
|
||||
uint64_t state_key
|
||||
|
||||
if tokens.length == 0:
|
||||
return 0
|
||||
|
||||
cdef atom_t[CONTEXT_SIZE] context
|
||||
cdef int n_feats
|
||||
cdef Pool mem = Pool()
|
||||
cdef State* state = init_state(mem, tokens.data, tokens.length)
|
||||
cdef State* state = init_state(mem, tokens.data, tokens.length) # TODO
|
||||
cdef Transition guess
|
||||
while not is_final(state):
|
||||
fill_context(context, state)
|
||||
scores = self.model.score(context)
|
||||
guess = self.moves.best_valid(scores, state)
|
||||
guess.do(&guess, state)
|
||||
# Messily tell Tokens object the string names of the dependency labels
|
||||
dep_strings = [None] * len(self.moves.label_ids)
|
||||
for label, id_ in self.moves.label_ids.items():
|
||||
dep_strings[id_] = label
|
||||
tokens._dep_strings = tuple(dep_strings)
|
||||
tokens.is_parsed = True
|
||||
# TODO: Clean this up.
|
||||
tokens._py_tokens = [None] * tokens.length
|
||||
tokens.set_parse(state.sent, self.moves.label_ids) # TODO
|
||||
return 0
|
||||
|
||||
def train_sent(self, Tokens tokens, list gold_heads, list gold_labels,
|
||||
force_gold=False):
|
||||
def train_sent(self, Tokens tokens, GoldParse gold, force_gold=False):
|
||||
cdef:
|
||||
int n_feats
|
||||
const Feature* feats
|
||||
const weight_t* scores
|
||||
Transition guess
|
||||
Transition gold
|
||||
|
||||
cdef int n_feats
|
||||
cdef atom_t[CONTEXT_SIZE] context
|
||||
cdef Pool mem = Pool()
|
||||
cdef int* heads_array = <int*>mem.alloc(tokens.length, sizeof(int))
|
||||
cdef int* labels_array = <int*>mem.alloc(tokens.length, sizeof(int))
|
||||
cdef int i
|
||||
for i in range(tokens.length):
|
||||
if gold_heads[i] is None:
|
||||
heads_array[i] = -1
|
||||
labels_array[i] = -1
|
||||
else:
|
||||
heads_array[i] = gold_heads[i]
|
||||
labels_array[i] = self.moves.label_ids[gold_labels[i]]
|
||||
atom_t[CONTEXT_SIZE] context
|
||||
|
||||
py_words = [t.orth_ for t in tokens]
|
||||
py_moves = ['S', 'D', 'L', 'R', 'BS', 'BR']
|
||||
history = []
|
||||
#print py_words
|
||||
cdef Pool mem = Pool()
|
||||
cdef State* state = init_state(mem, tokens.data, tokens.length)
|
||||
|
||||
while not is_final(state):
|
||||
fill_context(context, state)
|
||||
scores = self.model.score(context)
|
||||
guess = self.moves.best_valid(scores, state)
|
||||
best = self.moves.best_gold(&guess, scores, state, heads_array, labels_array)
|
||||
history.append((py_moves[best.move], print_state(state, py_words)))
|
||||
self.model.update(context, guess.clas, best.clas, guess.cost)
|
||||
gold = self.moves.best_gold(scores, state, gold)
|
||||
cost = guess.get_cost(&guess, state, gold)
|
||||
self.model.update(context, guess.clas, best.clas, cost)
|
||||
if force_gold:
|
||||
best.do(&best, state)
|
||||
else:
|
||||
guess.do(&guess, state)
|
||||
cdef int n_corr = 0
|
||||
for i in range(tokens.length):
|
||||
if gold_heads[i] != -1:
|
||||
n_corr += (i + state.sent[i].head) == gold_heads[i]
|
||||
n_corr = gold.heads_correct(state.sent, score_punct=True) # TODO
|
||||
if force_gold and n_corr != tokens.length:
|
||||
#print py_words
|
||||
#print gold_heads
|
||||
#for move, state_str in history:
|
||||
# print move, state_str
|
||||
#for i in range(tokens.length):
|
||||
# print py_words[i], py_words[i + state.sent[i].head], py_words[gold_heads[i]]
|
||||
raise OracleError
|
||||
return n_corr
|
||||
|
|
Loading…
Reference in New Issue
Block a user