* Work on parser. Up to 92 UAS on YM labels

This commit is contained in:
Matthew Honnibal 2014-12-18 09:05:31 +11:00
parent 55de747bfc
commit 8446ebfbbb
8 changed files with 220 additions and 25 deletions

View File

@ -11,6 +11,7 @@ from .lexeme cimport Lexeme
from .tagger cimport Tagger from .tagger cimport Tagger
from .utf8string cimport StringStore, UniStr from .utf8string cimport StringStore, UniStr
from .morphology cimport Morphologizer from .morphology cimport Morphologizer
from .syntax.parser cimport GreedyParser
cdef union LexemesOrTokens: cdef union LexemesOrTokens:
@ -43,6 +44,7 @@ cdef class Language:
cpdef readonly Lexicon lexicon cpdef readonly Lexicon lexicon
cpdef readonly Tagger pos_tagger cpdef readonly Tagger pos_tagger
cpdef readonly Morphologizer morphologizer cpdef readonly Morphologizer morphologizer
cpdef readonly GreedyParser parser
cdef PreshMap _pos_cache cdef PreshMap _pos_cache
cdef object _prefix_re cdef object _prefix_re

View File

@ -44,15 +44,19 @@ cdef class Language:
self.pos_tagger = None self.pos_tagger = None
self.morphologizer = None self.morphologizer = None
def load(self, pos_dir=None): def load(self, pos_dir=None, parser_dir=None):
self.lexicon.load(path.join(util.DATA_DIR, self.name, 'lexemes')) self.lexicon.load(path.join(util.DATA_DIR, self.name, 'lexemes'))
self.lexicon.strings.load(path.join(util.DATA_DIR, self.name, 'strings')) self.lexicon.strings.load(path.join(util.DATA_DIR, self.name, 'strings'))
if pos_dir is None: if pos_dir is None:
pos_dir = path.join(util.DATA_DIR, self.name, 'pos') pos_dir = path.join(util.DATA_DIR, self.name, 'pos')
if parser_dir is None:
parser_dir = path.join(util.DATA_DIR, self.name, 'deps')
if path.exists(pos_dir): if path.exists(pos_dir):
self.pos_tagger = Tagger(pos_dir) self.pos_tagger = Tagger(pos_dir)
self.morphologizer = Morphologizer(self.lexicon.strings, pos_dir) self.morphologizer = Morphologizer(self.lexicon.strings, pos_dir)
#self.load_pos_cache(path.join(util.DATA_DIR, self.name, 'pos', 'bigram_cache_2m')) #self.load_pos_cache(path.join(util.DATA_DIR, self.name, 'pos', 'bigram_cache_2m'))
if path.exists(parser_dir):
self.parser = GreedyParser(parser_dir)
cpdef Tokens tokens_from_list(self, list strings): cpdef Tokens tokens_from_list(self, list strings):
cdef int length = sum([len(s) for s in strings]) cdef int length = sum([len(s) for s in strings])

View File

@ -22,6 +22,7 @@ cdef int fill_context(atom_t* context, State* state) except -1
# NB: The order of the enum is _NOT_ arbitrary!! # NB: The order of the enum is _NOT_ arbitrary!!
cpdef enum: cpdef enum:
S2w S2w
S2W
S2p S2p
S2c S2c
S2c4 S2c4
@ -29,6 +30,7 @@ cpdef enum:
S2L S2L
S1w S1w
S1W
S1p S1p
S1c S1c
S1c4 S1c4
@ -36,6 +38,7 @@ cpdef enum:
S1L S1L
S1rw S1rw
S1rW
S1rp S1rp
S1rc S1rc
S1rc4 S1rc4
@ -43,6 +46,7 @@ cpdef enum:
S1rL S1rL
S0lw S0lw
S0lW
S0lp S0lp
S0lc S0lc
S0lc4 S0lc4
@ -50,6 +54,7 @@ cpdef enum:
S0lL S0lL
S0l2w S0l2w
S0l2W
S0l2p S0l2p
S0l2c S0l2c
S0l2c4 S0l2c4
@ -57,6 +62,7 @@ cpdef enum:
S0l2L S0l2L
S0w S0w
S0W
S0p S0p
S0c S0c
S0c4 S0c4
@ -64,6 +70,7 @@ cpdef enum:
S0L S0L
S0r2w S0r2w
S0r2W
S0r2p S0r2p
S0r2c S0r2c
S0r2c4 S0r2c4
@ -71,6 +78,7 @@ cpdef enum:
S0r2L S0r2L
S0rw S0rw
S0rW
S0rp S0rp
S0rc S0rc
S0rc4 S0rc4
@ -78,6 +86,7 @@ cpdef enum:
S0rL S0rL
N0l2w N0l2w
N0l2W
N0l2p N0l2p
N0l2c N0l2c
N0l2c4 N0l2c4
@ -85,6 +94,7 @@ cpdef enum:
N0l2L N0l2L
N0lw N0lw
N0lW
N0lp N0lp
N0lc N0lc
N0lc4 N0lc4
@ -92,6 +102,7 @@ cpdef enum:
N0lL N0lL
N0w N0w
N0W
N0p N0p
N0c N0c
N0c4 N0c4
@ -99,6 +110,7 @@ cpdef enum:
N0L N0L
N1w N1w
N1W
N1p N1p
N1c N1c
N1c4 N1c4
@ -106,6 +118,7 @@ cpdef enum:
N1L N1L
N2w N2w
N2W
N2p N2p
N2c N2c
N2c4 N2c4
@ -119,5 +132,9 @@ cpdef enum:
S0rv S0rv
S1lv S1lv
S1rv S1rv
S0_has_head
S1_has_head
S2_has_head
CONTEXT_SIZE CONTEXT_SIZE

View File

@ -13,7 +13,8 @@ from itertools import combinations
from ..tokens cimport TokenC from ..tokens cimport TokenC
from ._state cimport State from ._state cimport State
from ._state cimport get_s2, get_s1, get_s0, get_n0, get_n1, get_n2 from ._state cimport get_s2, get_s1, get_s0, get_n0, get_n1, get_n2
from ._state cimport get_left, get_right from ._state cimport has_head, get_left, get_right
from ._state cimport count_left_kids, count_right_kids
cdef inline void fill_token(atom_t* context, const TokenC* token) nogil: cdef inline void fill_token(atom_t* context, const TokenC* token) nogil:
@ -24,10 +25,12 @@ cdef inline void fill_token(atom_t* context, const TokenC* token) nogil:
context[3] = 0 context[3] = 0
context[4] = 0 context[4] = 0
context[5] = 0 context[5] = 0
context[6] = 0
else: else:
context[0] = token.lex.sic context[0] = token.lex.sic
context[1] = token.pos context[1] = token.lemma
context[2] = token.lex.cluster context[2] = token.pos
context[3] = token.lex.cluster
# We've read in the string little-endian, so now we can take & (2**n)-1 # We've read in the string little-endian, so now we can take & (2**n)-1
# to get the first n bits of the cluster. # to get the first n bits of the cluster.
# e.g. s = "1110010101" # e.g. s = "1110010101"
@ -40,9 +43,9 @@ cdef inline void fill_token(atom_t* context, const TokenC* token) nogil:
# What we're doing here is picking a number where all bits are 1, e.g. # What we're doing here is picking a number where all bits are 1, e.g.
# 15 is 1111, 63 is 111111 and doing bitwise AND, so getting all bits in # 15 is 1111, 63 is 111111 and doing bitwise AND, so getting all bits in
# the source that are set to 1. # the source that are set to 1.
context[3] = token.lex.cluster & 63 context[4] = token.lex.cluster & 63
context[4] = token.lex.cluster & 15 context[5] = token.lex.cluster & 15
context[5] = token.dep_tag context[6] = token.dep_tag
cdef int fill_context(atom_t* context, State* state) except -1: cdef int fill_context(atom_t* context, State* state) except -1:
@ -66,12 +69,160 @@ cdef int fill_context(atom_t* context, State* state) except -1:
context[dist] = state.stack[0] - state.i context[dist] = state.stack[0] - state.i
else: else:
context[dist] = 0 context[dist] = 0
context[N0lv] = 0 context[N0lv] = max(count_left_kids(get_n0(state)), 5)
context[S0lv] = 0 context[S0lv] = max(count_left_kids(get_s0(state)), 5)
context[S0rv] = 0 context[S0rv] = max(count_right_kids(get_s0(state)), 5)
context[S1lv] = 0 context[S1lv] = max(count_left_kids(get_s1(state)), 5)
context[S1rv] = 0 context[S1rv] = max(count_right_kids(get_s1(state)), 5)
context[S0_has_head] = 0
context[S1_has_head] = 0
context[S2_has_head] = 0
if state.stack_len >= 1:
context[S0_has_head] = has_head(get_s0(state)) + 1
if state.stack_len >= 2:
context[S1_has_head] = has_head(get_s1(state)) + 1
if state.stack_len >= 3:
context[S2_has_head] = has_head(get_s2(state))
unigrams = (
(S2W, S2p),
(S2p,),
(S2c,),
(S2L,),
(S1W, S1p),
(S1p,),
(S1c,),
(S1L,),
(S0W, S0p),
(S0p,),
(S0c,),
(S0L,),
(N0W, N0p),
(N0p,),
(N0c,),
(N0L,),
(N1W, N1p),
(N1p,),
(N1c,),
(N2W, N2p),
(N2p,),
(N2c,),
(S0r2W, S0r2p),
(S0r2p,),
(S0r2c,),
(S0r2L,),
(S0rW, S0rp),
(S0rp,),
(S0rc,),
(S0rL,),
(S0l2W, S0l2p),
(S0l2p,),
(S0l2c,),
(S0l2L,),
(S0lW, S0lp),
(S0lp,),
(S0lc,),
(S0lL,),
(N0l2W, N0l2p),
(N0l2p,),
(N0l2c,),
(N0l2L,),
(N0lW, N0lp),
(N0lp,),
(N0lc,),
(N0lL,),
)
s0_n0 = (
(S0W, S0p, N0W, N0p),
(S0c, S0p, N0c, N0p),
(S0p, N0p),
(S0W, N0p),
(S0p, N0W),
(S0W, N0c),
(S0c, N0W),
(S0p, N0c),
(S0c, N0p),
(S0W, S0rp, N0p),
(S0p, S0rp, N0p),
(S0p, N0lp, N0W),
(S0p, N0lp, N0p),
)
s1_n0 = (
(S0_has_head, S1p, N0p),
(S0_has_head, S1c, N0c),
(S0_has_head, S1c, N0p),
(S0_has_head, S1p, N0c),
(S0_has_head, S1W, S1p, N0p),
(S0_has_head, S1p, N0W, N0p)
)
s0_n1 = (
(S0p, N1p),
(S0c, N1c),
(S0c, N1p),
(S0p, N1c),
(S0W, S0p, N1p),
(S0p, N1W, N1p)
)
n0_n1 = (
(N0W, N0p, N1W, N1p),
(N0W, N0p, N1p),
(N0p, N1W, N1p),
(N0c, N0p, N1c, N1p),
(N0c, N1c),
(N0p, N1c),
)
tree_shape = (
(dist,),
(S0p, S0_has_head, S1_has_head, S2_has_head),
(S0p, S0lv, S0rv),
(N0p, N0lv),
#(S0p, S0_left_shape),
#(S0p, S0_right_shape),
#(N0p, N0_left_shape),
#(S0p, S0_left_shape, N0_left_shape)
)
trigrams = (
(N0p, N1p, N2p),
(S0p, S0lp, S0l2p),
(S0p, S0rp, S0r2p),
(S0p, S1p, S2p),
(S1p, S0p, N0p),
(S0p, S0lp, N0p),
(S0p, N0p, N0lp),
(N0p, N0lp, N0l2p),
(S0W, S0p, S0rL, S0r2L),
(S0p, S0rL, S0r2L),
(S0W, S0p, S0lL, S0l2L),
(S0p, S0lL, S0l2L),
(N0W, N0p, N0lL, N0l2L),
(N0p, N0lL, N0l2L),
)
arc_eager = ( arc_eager = (
(S0w, S0p), (S0w, S0p),
@ -86,7 +237,6 @@ arc_eager = (
(N2w, N2p), (N2w, N2p),
(N2w,), (N2w,),
(N2p,), (N2p,),
(S0w, S0p, N0w, N0p), (S0w, S0p, N0w, N0p),
(S0w, S0p, N0w), (S0w, S0p, N0w),
(S0w, N0w, N0p), (S0w, N0w, N0p),

View File

@ -20,8 +20,7 @@ cdef int pop_stack(State *s) except -1
cdef int push_stack(State *s) except -1 cdef int push_stack(State *s) except -1
cdef inline bint has_head(const TokenC* t) nogil: cdef bint has_head(const TokenC* t) nogil
return t.head != 0
cdef inline int get_idx(const State* s, const TokenC* t) nogil: cdef inline int get_idx(const State* s, const TokenC* t) nogil:
@ -79,12 +78,10 @@ cdef int head_in_stack(const State *s, const int child, int* gold) except -1
cdef State* init_state(Pool mem, TokenC* sent, const int sent_length) except NULL cdef State* init_state(Pool mem, TokenC* sent, const int sent_length) except NULL
cdef inline int count_left_kids(const TokenC* head) nogil: cdef int count_left_kids(const TokenC* head) nogil
return _popcount(head.l_kids)
cdef inline int count_right_kids(const TokenC* head) nogil: cdef int count_right_kids(const TokenC* head) nogil
return _popcount(head.r_kids)
# From https://en.wikipedia.org/wiki/Hamming_weight # From https://en.wikipedia.org/wiki/Hamming_weight

View File

@ -3,6 +3,7 @@ from libc.string cimport memmove
from cymem.cymem cimport Pool from cymem.cymem cimport Pool
from ..lexeme cimport EMPTY_LEXEME from ..lexeme cimport EMPTY_LEXEME
from ..tokens cimport TokenC
cdef int add_dep(State *s, int head, int child, int label) except -1: cdef int add_dep(State *s, int head, int child, int label) except -1:
@ -88,6 +89,19 @@ cdef const TokenC* get_right(const State* s, const TokenC* head, const int idx)
return NULL return NULL
cdef bint has_head(const TokenC* t) nogil:
return t.head != 0
cdef int count_left_kids(const TokenC* head) nogil:
return _popcount(head.l_kids)
cdef int count_right_kids(const TokenC* head) nogil:
return _popcount(head.r_kids)
DEF PADDING = 5 DEF PADDING = 5

View File

@ -138,6 +138,16 @@ cdef class TransitionSystem:
unl_costs[RIGHT] = _right_cost(s, gold_heads) if _can_right(s) else -1 unl_costs[RIGHT] = _right_cost(s, gold_heads) if _can_right(s) else -1
unl_costs[REDUCE] = _reduce_cost(s, gold_heads) if _can_reduce(s) else -1 unl_costs[REDUCE] = _reduce_cost(s, gold_heads) if _can_reduce(s) else -1
#s0_buff_head = head_in_buffer(s, get_s0(s), gold_heads)
#s0_stack_head = head_in_stack(s, get_s0(s), gold_heads)
#s0_buff_kids = children_in_buffer(s, get_s0(s), gold_heads)
#s0_stack_kids = children_in_stack(s, get_s0(s), gold_heads)
#n0_buff_head = head_in_buffer(s, get_n0(s), gold_heads)
#n0_stack_head = head_in_stack(s, get_n0(s), gold_heads)
#n0_buff_kids = children_in_buffer(s, get_n0(s), gold_heads)
#n0_stack_kids = children_in_buffer(s, get_n0(s), gold_heads)
cdef int cost cdef int cost
cdef int move cdef int move
cdef int label cdef int label

View File

@ -49,7 +49,11 @@ cdef unicode print_state(State* s, list words):
def get_templates(name): def get_templates(name):
pf = _parse_features pf = _parse_features
return pf.arc_eager if name == 'zhang':
return pf.arc_eager
else:
templs = pf.unigrams + pf.s0_n0 + pf.s1_n0 + pf.s0_n1 + pf.n0_n1 + pf.tree_shape + pf.trigrams
return templs
cdef class GreedyParser: cdef class GreedyParser:
@ -58,7 +62,6 @@ cdef class GreedyParser:
self.cfg = Config.read(model_dir, 'config') self.cfg = Config.read(model_dir, 'config')
self.extractor = Extractor(get_templates(self.cfg.features)) self.extractor = Extractor(get_templates(self.cfg.features))
self.moves = TransitionSystem(self.cfg.left_labels, self.cfg.right_labels) self.moves = TransitionSystem(self.cfg.left_labels, self.cfg.right_labels)
self.model = LinearModel(self.moves.n_moves, self.extractor.n_templ) self.model = LinearModel(self.moves.n_moves, self.extractor.n_templ)
if os.path.exists(pjoin(model_dir, 'model')): if os.path.exists(pjoin(model_dir, 'model')):
self.model.load(pjoin(model_dir, 'model')) self.model.load(pjoin(model_dir, 'model'))
@ -77,9 +80,7 @@ cdef class GreedyParser:
fill_context(context, state) fill_context(context, state)
feats = self.extractor.get_feats(context, &n_feats) feats = self.extractor.get_feats(context, &n_feats)
scores = self.model.get_scores(feats, n_feats) scores = self.model.get_scores(feats, n_feats)
guess = self.moves.best_valid(scores, state) guess = self.moves.best_valid(scores, state)
self.moves.transition(state, guess) self.moves.transition(state, guess)
return 0 return 0
@ -110,7 +111,7 @@ cdef class GreedyParser:
self.moves.transition(state, guess) self.moves.transition(state, guess)
cdef int n_corr = 0 cdef int n_corr = 0
for i in range(tokens.length): for i in range(tokens.length):
n_corr += (i + state.sent[i].head) == heads_array[i] n_corr += (i + state.sent[i].head) == gold_heads[i]
return n_corr return n_corr