* Work on integrating a greedy dependency parser

This commit is contained in:
Matthew Honnibal 2014-12-16 08:06:04 +11:00
parent 24ffc32f2f
commit 9e00798820
8 changed files with 535 additions and 0 deletions

View File

0
spacy/parser/__init__.py Normal file
View File

78
spacy/parser/_state.pxd Normal file
View File

@ -0,0 +1,78 @@
from cymem.cymem cimport Pool
cdef struct Subtree:
int[5] kids
int[5] labels
int length
cdef struct State:
double score
size_t i
size_t n
size_t stack_len
size_t top
size_t* stack
Subtree* lefts
Subtree* rights
int* heads
int* labels
cdef int add_dep(const State *s, size_t head, size_t child, size_t label) except -1
cdef size_t pop_stack(State *s) except 0
cdef int push_stack(State *s) except -1
cdef inline size_t get_s1(const State *s) nogil:
if s.stack_len < 2:
return 0
return s.stack[s.stack_len - 2]
cdef inline size_t get_l(const State *s, size_t head) nogil:
cdef const Subtree* subtree = &s.lefts[head]
if subtree.length == 0:
return 0
return subtree.kids[subtree.length - 1]
cdef inline size_t get_l2(const State *s, size_t head) nogil:
cdef const Subtree* subtree = &s.lefts[head]
if subtree.length < 2:
return 0
return subtree.kids[subtree.length - 2]
cdef inline size_t get_r(const State *s, size_t head) nogil:
cdef const Subtree* subtree = &s.rights[head]
if subtree.length == 0:
return 0
return subtree.kids[subtree.length - 1]
cdef inline size_t get_r2(const State *s, size_t head) nogil:
cdef const Subtree* subtree = &s.rights[head]
if subtree.length < 2:
return 0
return subtree.kids[subtree.length - 2]
cdef inline bint at_eol(const State *s) nogil:
return s.i >= (s.n - 1)
cdef inline bint is_final(State *s) nogil:
return at_eol(s) and s.stack_len == 0
cdef int has_child_in_buffer(State *s, size_t word, int* gold) except -1
cdef int has_head_in_buffer(State *s, size_t word, int* gold) except -1
cdef int has_child_in_stack(State *s, size_t word, int* gold) except -1
cdef int has_head_in_stack(State *s, size_t word, int* gold) except -1
cdef State* init_state(Pool mem, const int sent_length) except NULL

97
spacy/parser/_state.pyx Normal file
View File

@ -0,0 +1,97 @@
# cython: profile=True
from libc.string cimport memmove
from cymem.cymem cimport Pool
cdef int add_dep(State *s, size_t head, size_t child, size_t label) except -1:
s.heads[child] = head
s.labels[child] = label
cdef Subtree* subtree = &s.lefts[head] if child < head else &s.rights[head]
if subtree.length == 5:
memmove(subtree.kids, &subtree.kids[1], 4 * sizeof(int))
memmove(subtree.labels, &subtree.labels[1], 4 * sizeof(int))
subtree.kids[4] = child
subtree.labels[4] = label
else:
subtree.kids[subtree.length - 1] = child
subtree.kids[subtree.length - 1] = label
subtree.length += 1
cdef size_t pop_stack(State *s) except 0:
cdef size_t popped
assert s.stack_len >= 1
popped = s.top
s.top = get_s1(s)
s.stack_len -= 1
assert s.top <= s.n, s.top
assert popped != 0
return popped
cdef int push_stack(State *s) except -1:
s.top = s.i
s.stack[s.stack_len] = s.i
s.stack_len += 1
assert s.top <= s.n
s.i += 1
cdef int has_child_in_buffer(State *s, size_t word, int* gold_heads) except -1:
assert word != 0
cdef size_t buff_i
cdef int n = 0
for buff_i in range(s.i, s.n):
if s.heads[buff_i] == 0 and gold_heads[buff_i] == word:
n += 1
return n
cdef int has_head_in_buffer(State *s, size_t word, int* gold_heads) except -1:
assert word != 0
cdef size_t buff_i
for buff_i in range(s.i, s.n):
if s.heads[buff_i] == 0 and gold_heads[word] == buff_i:
return 1
return 0
cdef int has_child_in_stack(State *s, size_t word, int* gold_heads) except -1:
assert word != 0
cdef size_t i, stack_i
cdef int n = 0
for i in range(s.stack_len):
stack_i = s.stack[i]
# Should this be sensitive to whether the word has a head already?
if gold_heads[stack_i] == word:
n += 1
return n
cdef int has_head_in_stack(State *s, size_t word, int* gold_heads) except -1:
assert word != 0
cdef size_t i, stack_i
for i in range(s.stack_len):
stack_i = s.stack[i]
if gold_heads[word] == stack_i:
return 1
return 0
DEF PADDING = 5
cdef State* init_state(Pool mem, const int sent_length) except NULL:
cdef size_t i
cdef State* s = <State*>mem.alloc(1, sizeof(State))
s.n = sent_length
s.i = 1
s.top = 0
s.stack_len = 0
n = s.n + PADDING
s.stack = <size_t*>mem.alloc(n, sizeof(size_t))
s.heads = <int*>mem.alloc(n, sizeof(int))
s.labels = <int*>mem.alloc(n, sizeof(int))
s.lefts = <Subtree*>mem.alloc(n, sizeof(Subtree))
s.rights = <Subtree*>mem.alloc(n, sizeof(Subtree))
return s

View File

@ -0,0 +1,23 @@
from cymem.cymem cimport Pool
from thinc.typedefs cimport weight_t
from ._state cimport State
cdef struct Transition:
int move
int label
cdef class TransitionSystem:
cdef Pool mem
cdef readonly int n_moves
cdef const Transition* _moves
cdef int best_valid(self, const weight_t* scores, const State* s) except -1
cdef int best_gold(self, const weight_t* scores, const State* s,
list gold_heads, list gold_labels) except -1
cdef int transition(self, State *s, const int clas) except -1

162
spacy/parser/arc_eager.pyx Normal file
View File

@ -0,0 +1,162 @@
# cython: profile=True
from ._state cimport State
from ._state cimport at_eol, pop_stack, push_stack, add_dep
from ._state cimport has_head_in_buffer, has_child_in_buffer
from ._state cimport has_head_in_stack, has_child_in_stack
import index.hashes
cdef enum:
ERR
SHIFT
REDUCE
LEFT
RIGHT
N_MOVES
cdef inline bint _can_shift(const State* s) nogil:
return not at_eol(s)
cdef inline bint _can_right(const State* s) nogil:
return s.stack_len >= 1 and not at_eol(s)
cdef inline bint _can_left(const State* s) nogil:
return s.stack_len >= 1 and s.heads[s.top] == 0
cdef inline bint _can_reduce(const State* s) nogil:
return s.stack_len >= 2 and s.heads[s.top] != 0
cdef int _shift_cost(const State* s, list gold) except -1:
assert not at_eol(s)
cost = 0
cost += has_head_in_stack(s, s.i, gold)
cost += has_child_in_stack(s, s.i, gold)
return cost
cdef int _right_cost(const State* s, list gold) except -1:
assert s.stack_len >= 1
cost = 0
if gold[s.i] == s.top:
return cost
cost += has_head_in_buffer(s, s.i, gold)
cost += has_child_in_stack(s, s.i, gold)
cost += has_head_in_stack(s, s.i, gold)
return cost
cdef int _left_cost(const State* s, list gold) except -1:
assert s.stack_len >= 1
cost = 0
if gold[s.top] == s.i:
return cost
cost += has_head_in_buffer(s, s.top, gold)
cost += has_child_in_buffer(s, s.top, gold)
return cost
cdef int _reduce_cost(const State* s, list gold) except -1:
assert s.stack_len >= 2
cost = 0
cost += has_child_in_buffer(s, s.top, gold)
return cost
cdef class TransitionSystem:
def __init__(self, list left_labels, list right_labels):
self.mem = Pool()
self.n_moves = 2 + len(left_labels) + len(right_labels)
moves = <Transition*>self.mem.alloc(self.n_moves, sizeof(Transition))
cdef int i = 0
moves[i].move = SHIFT
moves[i].label = 0
i += 1
moves[i].move = REDUCE
moves[i].label = 0
i += 1
cdef int label
for label in left_labels:
moves[i].move = LEFT
moves[i].label = label
i += 1
for label in right_labels:
moves[i].move = RIGHT
moves[i].label = label
i += 1
self._moves = moves
cdef int transition(self, State *s, const int clas) except -1:
cdef const Transition* t = &self._moves[clas]
if t.move == SHIFT:
push_stack(s)
elif t.move == LEFT:
add_dep(s, s.i, s.top, t.label)
pop_stack(s)
elif t.move == RIGHT:
add_dep(s, s.top, s.i, t.label)
push_stack(s)
elif t.move == REDUCE:
pop_stack(s)
else:
raise StandardError(t.move)
cdef int best_valid(self, const weight_t* scores, const State* s) except -1:
cdef bint[N_MOVES] valid
valid[SHIFT] = _can_shift(s)
valid[LEFT] = _can_left(s)
valid[RIGHT] = _can_right(s)
valid[REDUCE] = _can_reduce(s)
cdef int best = -1
cdef weight_t score
cdef int i
for i in range(self.n_moves):
if valid[self._moves[i].move] and scores[i] > score:
best = i
score = scores[i]
assert best >= -1, "No valid moves found"
return best
cdef int best_gold(self, const weight_t* scores, const State* s,
list gold_heads, list gold_labels) except -1:
cdef int[N_MOVES] unl_costs
unl_costs[SHIFT] = _shift_cost(s, gold_heads) if _can_shift(s) else -1
unl_costs[LEFT] = _left_cost(s, gold_heads) if _can_left(s) else -1
unl_costs[RIGHT] = _right_cost(s, gold_heads) if _can_right(s) else -1
unl_costs[REDUCE] = _reduce_cost(s, gold_heads) if _can_reduce(s) else -1
cdef int cost
cdef int label_cost
cdef int move
cdef int label
cdef int best = -1
cdef weight_t score
cdef int i
for i in range(self.n_moves):
move = self._moves[i].move
label = self._moves[i].label
if unl_costs[move] == 0:
if move == SHIFT or move == REDUCE:
label_cost = 0
elif move == LEFT:
if gold_heads[s.top] == s.i:
label_cost = label != gold_labels[s.top]
else:
label_cost = 0
elif move == RIGHT:
if gold_heads[s.i] == s.top:
label_cost = label != gold_labels[s.i]
else:
label_cost = 0
else:
raise StandardError("Unknown Move")
if label_cost == 0 and scores[i] > score:
best = i
score = scores[i]
assert best >= -1, "No gold moves found"
return best

15
spacy/parser/parser.pxd Normal file
View File

@ -0,0 +1,15 @@
from thinc.features cimport Extractor
from thinc.learner cimport LinearModel
from .arc_eager cimport TransitionSystem
from ..tokens cimport Tokens, TokenC
cdef class Parser:
cdef object cfg
cdef Extractor extractor
cdef LinearModel model
cdef TransitionSystem moves
cpdef int parse(self, Tokens tokens) except -1

160
spacy/parser/parser.pyx Normal file
View File

@ -0,0 +1,160 @@
# cython: profile=True
"""
MALT-style dependency parser
"""
cimport cython
import random
import os.path
from os.path import join as pjoin
import shutil
import json
from cymem.cymem cimport Pool, Address
from thinc.typedefs cimport weight_t, class_t, feat_t, atom_t
from util import Config
from thinc.features cimport Extractor
from thinc.features cimport Feature
from thinc.features cimport count_feats
from thinc.learner cimport LinearModel
from ..tokens cimport Tokens, TokenC
from .arc_eager cimport TransitionSystem
from ._state cimport init_state, State, is_final, get_s1
VOCAB_SIZE = 1e6
TAG_SET_SIZE = 50
DEF CONTEXT_SIZE = 50
DEBUG = False
def set_debug(val):
global DEBUG
DEBUG = val
cdef str print_state(State* s, list words):
top = words[s.top]
second = words[get_s1(s)]
n0 = words[s.i]
n1 = words[s.i + 1]
return ' '.join((second, top, '|', n0, n1))
def train(sents, golds, model_dir, n_iter=15, feat_set=u'basic', seed=0):
if os.path.exists(model_dir):
shutil.rmtree(model_dir)
os.mkdir(model_dir)
left_labels, right_labels, dfl_labels = get_labels(golds)
Config.write(model_dir, 'config', features=feat_set, seed=seed,
left_labels=left_labels, right_labels=right_labels)
parser = Parser(model_dir)
indices = list(range(len(sents)))
for n in range(n_iter):
for i in indices:
parser.train_sent(sents[i], *golds[i])
#parser.tagger.train_sent(py_sent) # TODO
acc = float(parser.guide.n_corr) / parser.guide.total
print(parser.guide.end_train_iter(n) + '\t' +
parser.tagger.guide.end_train_iter(n))
random.shuffle(indices)
parser.guide.end_training()
parser.tagger.guide.end_training()
parser.guide.dump(pjoin(model_dir, 'model'), freq_thresh=0)
parser.tagger.guide.dump(pjoin(model_dir, 'tagger'))
return acc
def get_labels(sents):
'''Get alphabetically-sorted lists of left, right and disfluency labels that
occur in a sample of sentences. Used to determine the set of legal transitions
from the training set.
Args:
sentences (list[Input]): A list of Input objects, usually the training set.
Returns:
labels (tuple[list, list, list]): Sorted lists of left, right and disfluency
labels.
'''
left_labels = set()
right_labels = set()
# TODO
return list(sorted(left_labels)), list(sorted(right_labels))
def get_templates(feats_str):
'''Interpret feats_str, returning a list of template tuples. Each template
is a tuple of numeric indices, referring to positions in the context
array. See _parse_features.pyx for examples. The templates are applied by
thinc.features.Extractor, which picks out the appointed values and hashes
the resulting array, to produce a single feature code.
'''
return tuple()
cdef class Parser:
def __init__(self, model_dir):
assert os.path.exists(model_dir) and os.path.isdir(model_dir)
self.cfg = Config.read(model_dir, 'config')
self.extractor = Extractor(get_templates(self.cfg.features))
self.moves = TransitionSystem(self.cfg.left_labels, self.cfg.right_labels)
self.model = LinearModel(self.moves.n_moves, self.extractor.n_templ)
if os.path.exists(pjoin(model_dir, 'model')):
self.model.load(pjoin(model_dir, 'model'))
cpdef int parse(self, Tokens tokens) except -1:
cdef:
Feature* feats
weight_t* scores
cdef atom_t[CONTEXT_SIZE] context
cdef int n_feats
cdef Pool mem = Pool()
cdef State* state = init_state(mem, tokens.length)
while not is_final(state):
fill_context(context, state, tokens.data) # TODO
feats = self.extractor.get_feats(context, &n_feats)
scores = self.model.get_scores(feats, n_feats)
guess = self.moves.best_valid(scores, state)
self.moves.transition(state, guess)
# TODO output
def train_sent(self, Tokens tokens, list gold_heads, list gold_labels):
cdef:
Feature* feats
weight_t* scores
cdef int n_feats
cdef atom_t[CONTEXT_SIZE] context
cdef Pool mem = Pool()
cdef State* state = init_state(mem, tokens.length)
while not is_final(state):
fill_context(context, state, tokens.data) # TODO
feats = self.extractor.get_feats(context, &n_feats)
scores = self.model.get_scores(feats, n_feats)
guess = self.moves.best_valid(scores, state)
best = self.moves.best_gold(scores, state, gold_heads, gold_labels)
counts = {guess: {}, best: {}}
if guess != best:
count_feats(counts[guess], feats, n_feats, -1)
count_feats(counts[best], feats, n_feats, 1)
self.model.update(counts)
self.moves.transition(state, guess)
cdef int fill_context(atom_t* context, State* s, TokenC* sent) except -1:
pass