mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-13 18:56:36 +03:00
* Refactored transition system code now compiling. Still need to hook up label oracle, and test
This commit is contained in:
parent
6e86790a4e
commit
8c883cef58
|
@ -112,18 +112,17 @@ cdef int count_right_kids(const TokenC* head) nogil:
|
||||||
return _popcount(head.r_kids)
|
return _popcount(head.r_kids)
|
||||||
|
|
||||||
|
|
||||||
|
cdef State* init_state(Pool mem, const TokenC* sent, const int sent_len) except NULL:
|
||||||
cdef State* init_state(Pool mem, TokenC* sent, const int sent_length) except NULL:
|
cdef int padded_len = sent_len + PADDING + PADDING
|
||||||
cdef int padded_len = sent_length + PADDING + PADDING
|
|
||||||
cdef State* s = <State*>mem.alloc(1, sizeof(State))
|
cdef State* s = <State*>mem.alloc(1, sizeof(State))
|
||||||
s.stack = <int*>mem.alloc(padded_len, sizeof(int))
|
s.stack = <int*>mem.alloc(padded_len, sizeof(int))
|
||||||
for i in range(PADDING):
|
for i in range(PADDING):
|
||||||
s.stack[i] = -1
|
s.stack[i] = -1
|
||||||
s.stack += (PADDING - 1)
|
s.stack += (PADDING - 1)
|
||||||
assert s.stack[0] == -1
|
assert s.stack[0] == -1
|
||||||
s.sent = sent
|
s.sent = <TokenC*>mem.alloc(sent_len, sizeof(TokenC))
|
||||||
s.stack_len = 0
|
s.stack_len = 0
|
||||||
s.i = 0
|
s.i = 0
|
||||||
s.sent_len = sent_length
|
s.sent_len = sent_len
|
||||||
push_stack(s)
|
push_stack(s)
|
||||||
return s
|
return s
|
||||||
|
|
|
@ -33,14 +33,28 @@ cdef get_cost_func_t[N_MOVES] get_cost_funcs
|
||||||
|
|
||||||
|
|
||||||
cdef class ArcEager(TransitionSystem):
|
cdef class ArcEager(TransitionSystem):
|
||||||
|
@classmethod
|
||||||
|
def get_labels(cls, gold_parses):
|
||||||
|
labels = {RIGHT: {}, LEFT: {}}
|
||||||
|
for parse in gold_parses:
|
||||||
|
for i, (head, label) in enumerate(zip(parse.heads, parse.labels)):
|
||||||
|
if head > i:
|
||||||
|
labels[RIGHT][label] = True
|
||||||
|
else:
|
||||||
|
labels[LEFT][label] = True
|
||||||
|
return labels
|
||||||
|
|
||||||
cdef Transition init_transition(self, int clas, int move, int label) except *:
|
cdef Transition init_transition(self, int clas, int move, int label) except *:
|
||||||
return Transition(
|
# TODO: Apparent Cython bug here when we try to use the Transition()
|
||||||
score=0,
|
# constructor with the function pointers
|
||||||
clas=i,
|
cdef Transition t
|
||||||
move=move,
|
t.score = 0
|
||||||
label=label,
|
t.clas = clas
|
||||||
do=do_funcs[move],
|
t.move = move
|
||||||
get_cost=get_cost_funcs[move])
|
t.label = label
|
||||||
|
t.do = do_funcs[move]
|
||||||
|
t.get_cost = get_cost_funcs[move]
|
||||||
|
return t
|
||||||
|
|
||||||
cdef Transition best_valid(self, const weight_t* scores, const State* s) except *:
|
cdef Transition best_valid(self, const weight_t* scores, const State* s) except *:
|
||||||
cdef bint[N_MOVES] is_valid
|
cdef bint[N_MOVES] is_valid
|
||||||
|
@ -111,8 +125,8 @@ do_funcs[BREAK] = _do_break
|
||||||
cdef int _shift_cost(const Transition* self, const State* s, GoldParse gold) except -1:
|
cdef int _shift_cost(const Transition* self, const State* s, GoldParse gold) except -1:
|
||||||
assert not at_eol(s)
|
assert not at_eol(s)
|
||||||
cost = 0
|
cost = 0
|
||||||
cost += head_in_stack(s, s.i, gold.heads)
|
cost += head_in_stack(s, s.i, gold.c_heads)
|
||||||
cost += children_in_stack(s, s.i, gold.heads)
|
cost += children_in_stack(s, s.i, gold.c_heads)
|
||||||
if NON_MONOTONIC:
|
if NON_MONOTONIC:
|
||||||
cost += gold[s.stack[0]] == s.i
|
cost += gold[s.stack[0]] == s.i
|
||||||
# If we can break, and there's no cost to doing so, we should
|
# If we can break, and there's no cost to doing so, we should
|
||||||
|
@ -126,9 +140,9 @@ cdef int _right_cost(const Transition* self, const State* s, GoldParse gold) exc
|
||||||
cost = 0
|
cost = 0
|
||||||
if gold[s.i] == s.stack[0]:
|
if gold[s.i] == s.stack[0]:
|
||||||
return cost
|
return cost
|
||||||
cost += head_in_buffer(s, s.i, gold.heads)
|
cost += head_in_buffer(s, s.i, gold.c_heads)
|
||||||
cost += children_in_stack(s, s.i, gold.heads)
|
cost += children_in_stack(s, s.i, gold.c_heads)
|
||||||
cost += head_in_stack(s, s.i, gold.heads)
|
cost += head_in_stack(s, s.i, gold.c_heads)
|
||||||
if NON_MONOTONIC:
|
if NON_MONOTONIC:
|
||||||
cost += gold[s.stack[0]] == s.i
|
cost += gold[s.stack[0]] == s.i
|
||||||
return cost
|
return cost
|
||||||
|
@ -140,8 +154,8 @@ cdef int _left_cost(const Transition* self, const State* s, GoldParse gold) exce
|
||||||
if gold[s.stack[0]] == s.i:
|
if gold[s.stack[0]] == s.i:
|
||||||
return cost
|
return cost
|
||||||
|
|
||||||
cost += head_in_buffer(s, s.stack[0], gold.heads)
|
cost += head_in_buffer(s, s.stack[0], gold.c_heads)
|
||||||
cost += children_in_buffer(s, s.stack[0], gold.heads)
|
cost += children_in_buffer(s, s.stack[0], gold.c_heads)
|
||||||
if NON_MONOTONIC and s.stack_len >= 2:
|
if NON_MONOTONIC and s.stack_len >= 2:
|
||||||
cost += gold[s.stack[0]] == s.stack[-1]
|
cost += gold[s.stack[0]] == s.stack[-1]
|
||||||
cost += gold[s.stack[0]] == s.stack[0]
|
cost += gold[s.stack[0]] == s.stack[0]
|
||||||
|
@ -150,9 +164,9 @@ cdef int _left_cost(const Transition* self, const State* s, GoldParse gold) exce
|
||||||
|
|
||||||
cdef int _reduce_cost(const Transition* self, const State* s, GoldParse gold) except -1:
|
cdef int _reduce_cost(const Transition* self, const State* s, GoldParse gold) except -1:
|
||||||
cdef int cost = 0
|
cdef int cost = 0
|
||||||
cost += children_in_buffer(s, s.stack[0], gold.heads)
|
cost += children_in_buffer(s, s.stack[0], gold.c_heads)
|
||||||
if NON_MONOTONIC:
|
if NON_MONOTONIC:
|
||||||
cost += head_in_buffer(s, s.stack[0], gold.heads)
|
cost += head_in_buffer(s, s.stack[0], gold.c_heads)
|
||||||
return cost
|
return cost
|
||||||
|
|
||||||
|
|
||||||
|
@ -161,8 +175,8 @@ cdef int _break_cost(const Transition* self, const State* s, GoldParse gold) exc
|
||||||
cdef int cost = 0
|
cdef int cost = 0
|
||||||
# Number of deps between S0...Sn and N0...Nn
|
# Number of deps between S0...Sn and N0...Nn
|
||||||
for i in range(s.i, s.sent_len):
|
for i in range(s.i, s.sent_len):
|
||||||
cost += children_in_stack(s, i, gold.heads)
|
cost += children_in_stack(s, i, gold.c_heads)
|
||||||
cost += head_in_stack(s, i, gold.heads)
|
cost += head_in_stack(s, i, gold.c_heads)
|
||||||
return cost
|
return cost
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,22 @@
|
||||||
|
from cymem.cymem cimport Pool
|
||||||
|
|
||||||
from ..structs cimport TokenC
|
from ..structs cimport TokenC
|
||||||
|
|
||||||
|
|
||||||
cdef class GoldParse:
|
cdef class GoldParse:
|
||||||
cdef int* heads
|
cdef Pool mem
|
||||||
cdef int* labels
|
|
||||||
|
cdef int* c_heads
|
||||||
|
cdef int* c_labels
|
||||||
|
|
||||||
|
cdef int length
|
||||||
|
cdef int loss
|
||||||
|
|
||||||
|
cdef unicode raw_text
|
||||||
|
cdef list words
|
||||||
|
cdef list ids
|
||||||
|
cdef list tags
|
||||||
|
cdef list heads
|
||||||
|
cdef list labels
|
||||||
|
|
||||||
|
|
||||||
cdef int heads_correct(self, TokenC* tokens, bint score_punct=?) except -1
|
cdef int heads_correct(self, TokenC* tokens, bint score_punct=?) except -1
|
||||||
|
|
|
@ -1,11 +1,20 @@
|
||||||
cdef class GoldParse:
|
cdef class GoldParse:
|
||||||
def __init__(self):
|
def __init__(self, raw_text, words, ids, tags, heads, labels):
|
||||||
pass
|
self.mem = Pool()
|
||||||
|
self.loss = 0
|
||||||
|
self.length = len(words)
|
||||||
|
self.raw_text = raw_text
|
||||||
|
self.words = words
|
||||||
|
self.ids = ids
|
||||||
|
self.tags = tags
|
||||||
|
self.heads = heads
|
||||||
|
self.labels = labels
|
||||||
|
self.c_heads = <int*>self.mem.alloc(self.length, sizeof(int))
|
||||||
|
self.c_labels = <int*>self.mem.alloc(self.length, sizeof(int))
|
||||||
|
|
||||||
cdef int heads_correct(self, TokenC* tokens, bint score_punct=False) except -1:
|
cdef int heads_correct(self, TokenC* tokens, bint score_punct=False) except -1:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
"""
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_conll(cls, unicode sent_str):
|
def from_conll(cls, unicode sent_str):
|
||||||
ids = []
|
ids = []
|
||||||
|
@ -50,42 +59,44 @@ cdef class GoldParse:
|
||||||
for sent_str in tok_text.split('<SENT>')]
|
for sent_str in tok_text.split('<SENT>')]
|
||||||
return cls(raw_text, tokenized, ids, words, tags, heads, labels)
|
return cls(raw_text, tokenized, ids, words, tags, heads, labels)
|
||||||
|
|
||||||
cdef int heads_correct(self, TokenC* tokens, bint score_punct=False) except -1:
|
def align_to_tokens(self, tokens, label_ids):
|
||||||
pass
|
orig_words = list(self.words)
|
||||||
|
annot = zip(self.ids, self.tags, self.heads, self.labels)
|
||||||
def align_to_non_gold_tokens(self, tokens):
|
self.ids = []
|
||||||
# TODO
|
self.tags = []
|
||||||
tags = []
|
self.heads = []
|
||||||
heads = []
|
self.labels = []
|
||||||
labels = []
|
|
||||||
orig_words = list(words)
|
|
||||||
missed = []
|
missed = []
|
||||||
for token in tokens:
|
for token in tokens:
|
||||||
while annot and token.idx > annot[0][0]:
|
while annot and token.idx > annot[0][0]:
|
||||||
miss_id, miss_tag, miss_head, miss_label = annot.pop(0)
|
miss_id, miss_tag, miss_head, miss_label = annot.pop(0)
|
||||||
miss_w = words.pop(0)
|
miss_w = self.words.pop(0)
|
||||||
if not is_punct_label(miss_label):
|
if not is_punct_label(miss_label):
|
||||||
missed.append(miss_w)
|
missed.append(miss_w)
|
||||||
loss += 1
|
self.loss += 1
|
||||||
if not annot:
|
if not annot:
|
||||||
tags.append(None)
|
self.tags.append(None)
|
||||||
heads.append(None)
|
self.heads.append(None)
|
||||||
labels.append(None)
|
self.labels.append(None)
|
||||||
continue
|
continue
|
||||||
id_, tag, head, label = annot[0]
|
id_, tag, head, label = annot[0]
|
||||||
if token.idx == id_:
|
if token.idx == id_:
|
||||||
tags.append(tag)
|
self.tags.append(tag)
|
||||||
heads.append(head)
|
self.heads.append(head)
|
||||||
labels.append(label)
|
self.labels.append(label)
|
||||||
annot.pop(0)
|
annot.pop(0)
|
||||||
words.pop(0)
|
self.words.pop(0)
|
||||||
elif token.idx < id_:
|
elif token.idx < id_:
|
||||||
tags.append(None)
|
self.tags.append(None)
|
||||||
heads.append(None)
|
self.heads.append(None)
|
||||||
labels.append(None)
|
self.labels.append(None)
|
||||||
else:
|
else:
|
||||||
raise StandardError
|
raise StandardError
|
||||||
return loss, tags, heads, labels
|
mapped_heads = _map_indices_to_tokens(self.ids, self.heads)
|
||||||
|
for i in range(self.length):
|
||||||
|
self.c_heads[i] = mapped_heads[i]
|
||||||
|
self.c_labels[i] = label_ids[self.labels[i]]
|
||||||
|
return self.loss
|
||||||
|
|
||||||
|
|
||||||
def is_punct_label(label):
|
def is_punct_label(label):
|
||||||
|
@ -116,6 +127,7 @@ def _parse_line(line):
|
||||||
return id_, word, pos, head_idx, label
|
return id_, word, pos, head_idx, label
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
# TODO
|
# TODO
|
||||||
def evaluate(Language, dev_loc, model_dir, gold_preproc=False):
|
def evaluate(Language, dev_loc, model_dir, gold_preproc=False):
|
||||||
global loss
|
global loss
|
||||||
|
|
|
@ -26,9 +26,10 @@ from thinc.learner cimport LinearModel
|
||||||
from ..tokens cimport Tokens, TokenC
|
from ..tokens cimport Tokens, TokenC
|
||||||
|
|
||||||
from .arc_eager cimport TransitionSystem, Transition
|
from .arc_eager cimport TransitionSystem, Transition
|
||||||
from .arc_eager import OracleError
|
from .transition_system import OracleError
|
||||||
|
|
||||||
from ._state cimport init_state, State, is_final, get_idx, get_s0, get_s1, get_n0, get_n1
|
from ._state cimport init_state, State, is_final, get_idx, get_s0, get_s1, get_n0, get_n1
|
||||||
|
from .conll cimport GoldParse
|
||||||
|
|
||||||
from . import _parse_features
|
from . import _parse_features
|
||||||
from ._parse_features cimport fill_context, CONTEXT_SIZE
|
from ._parse_features cimport fill_context, CONTEXT_SIZE
|
||||||
|
@ -60,10 +61,10 @@ def get_templates(name):
|
||||||
|
|
||||||
|
|
||||||
cdef class GreedyParser:
|
cdef class GreedyParser:
|
||||||
def __init__(self, model_dir):
|
def __init__(self, model_dir, transition_system):
|
||||||
assert os.path.exists(model_dir) and os.path.isdir(model_dir)
|
assert os.path.exists(model_dir) and os.path.isdir(model_dir)
|
||||||
self.cfg = Config.read(model_dir, 'config')
|
self.cfg = Config.read(model_dir, 'config')
|
||||||
self.moves = TransitionSystem(self.cfg.left_labels, self.cfg.right_labels)
|
self.moves = transition_system(self.cfg.labels)
|
||||||
templates = get_templates(self.cfg.features)
|
templates = get_templates(self.cfg.features)
|
||||||
self.model = Model(self.moves.n_moves, templates, model_dir)
|
self.model = Model(self.moves.n_moves, templates, model_dir)
|
||||||
|
|
||||||
|
@ -74,23 +75,24 @@ cdef class GreedyParser:
|
||||||
cdef atom_t[CONTEXT_SIZE] context
|
cdef atom_t[CONTEXT_SIZE] context
|
||||||
cdef int n_feats
|
cdef int n_feats
|
||||||
cdef Pool mem = Pool()
|
cdef Pool mem = Pool()
|
||||||
cdef State* state = init_state(mem, tokens.data, tokens.length) # TODO
|
cdef State* state = init_state(mem, tokens.data, tokens.length)
|
||||||
cdef Transition guess
|
cdef Transition guess
|
||||||
while not is_final(state):
|
while not is_final(state):
|
||||||
fill_context(context, state)
|
fill_context(context, state)
|
||||||
scores = self.model.score(context)
|
scores = self.model.score(context)
|
||||||
guess = self.moves.best_valid(scores, state)
|
guess = self.moves.best_valid(scores, state)
|
||||||
guess.do(&guess, state)
|
guess.do(&guess, state)
|
||||||
tokens.set_parse(state.sent, self.moves.label_ids) # TODO
|
tokens.set_parse(state.sent, self.moves.label_ids)
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
def train_sent(self, Tokens tokens, GoldParse gold, force_gold=False):
|
def train(self, Tokens tokens, GoldParse gold, force_gold=False):
|
||||||
cdef:
|
cdef:
|
||||||
int n_feats
|
int n_feats
|
||||||
|
int cost
|
||||||
const Feature* feats
|
const Feature* feats
|
||||||
const weight_t* scores
|
const weight_t* scores
|
||||||
Transition guess
|
Transition guess
|
||||||
Transition gold
|
Transition best
|
||||||
|
|
||||||
atom_t[CONTEXT_SIZE] context
|
atom_t[CONTEXT_SIZE] context
|
||||||
|
|
||||||
|
@ -101,7 +103,7 @@ cdef class GreedyParser:
|
||||||
fill_context(context, state)
|
fill_context(context, state)
|
||||||
scores = self.model.score(context)
|
scores = self.model.score(context)
|
||||||
guess = self.moves.best_valid(scores, state)
|
guess = self.moves.best_valid(scores, state)
|
||||||
gold = self.moves.best_gold(scores, state, gold)
|
best = self.moves.best_gold(scores, state, gold)
|
||||||
cost = guess.get_cost(&guess, state, gold)
|
cost = guess.get_cost(&guess, state, gold)
|
||||||
self.model.update(context, guess.clas, best.clas, cost)
|
self.model.update(context, guess.clas, best.clas, cost)
|
||||||
if force_gold:
|
if force_gold:
|
||||||
|
|
|
@ -30,7 +30,7 @@ cdef class TransitionSystem:
|
||||||
|
|
||||||
cdef Transition init_transition(self, int clas, int move, int label) except *
|
cdef Transition init_transition(self, int clas, int move, int label) except *
|
||||||
|
|
||||||
cdef const Transition best_valid(self, const weight_t*, const State*) except *
|
cdef Transition best_valid(self, const weight_t* scores, const State* state) except *
|
||||||
|
|
||||||
cdef const Transition best_gold(self, const weight_t*, const State*,
|
cdef Transition best_gold(self, const weight_t* scores, const State* state,
|
||||||
GoldParse gold) except *
|
GoldParse gold) except *
|
||||||
|
|
|
@ -7,6 +7,10 @@ from thinc.typedefs cimport weight_t
|
||||||
cdef weight_t MIN_SCORE = -90000
|
cdef weight_t MIN_SCORE = -90000
|
||||||
|
|
||||||
|
|
||||||
|
class OracleError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
cdef class TransitionSystem:
|
cdef class TransitionSystem:
|
||||||
def __init__(self, dict labels_by_action):
|
def __init__(self, dict labels_by_action):
|
||||||
self.mem = Pool()
|
self.mem = Pool()
|
||||||
|
@ -28,7 +32,7 @@ cdef class TransitionSystem:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
cdef Transition best_gold(self, const weight_t* scores, const State* s,
|
cdef Transition best_gold(self, const weight_t* scores, const State* s,
|
||||||
const TokenC* gold) except *:
|
GoldParse gold) except *:
|
||||||
cdef Transition best
|
cdef Transition best
|
||||||
cdef weight_t score = MIN_SCORE
|
cdef weight_t score = MIN_SCORE
|
||||||
cdef int i
|
cdef int i
|
||||||
|
|
Loading…
Reference in New Issue
Block a user