* Add beam search capabilities to Parser. Rename GreedyParser to Parser.

This commit is contained in:
Matthew Honnibal 2015-06-02 00:28:02 +02:00
parent 62424e6c76
commit 58d5ac0944
3 changed files with 125 additions and 19 deletions

View File

@ -5,7 +5,7 @@ import re
from .. import orth from .. import orth
from ..vocab import Vocab from ..vocab import Vocab
from ..tokenizer import Tokenizer from ..tokenizer import Tokenizer
from ..syntax.parser import GreedyParser from ..syntax.parser import Parser
from ..syntax.arc_eager import ArcEager from ..syntax.arc_eager import ArcEager
from ..syntax.ner import BiluoPushDown from ..syntax.ner import BiluoPushDown
from ..tokens import Tokens from ..tokens import Tokens
@ -112,7 +112,7 @@ class English(object):
@property @property
def parser(self): def parser(self):
if self._parser is None: if self._parser is None:
self._parser = GreedyParser(self.vocab.strings, self._parser = Parser(self.vocab.strings,
path.join(self._data_dir, 'deps'), path.join(self._data_dir, 'deps'),
self.ParserTransitionSystem) self.ParserTransitionSystem)
return self._parser return self._parser
@ -120,7 +120,7 @@ class English(object):
@property @property
def entity(self): def entity(self):
if self._entity is None: if self._entity is None:
self._entity = GreedyParser(self.vocab.strings, self._entity = Parser(self.vocab.strings,
path.join(self._data_dir, 'ner'), path.join(self._data_dir, 'ner'),
self.EntityTransitionSystem) self.EntityTransitionSystem)
return self._entity return self._entity

View File

@ -1,11 +1,19 @@
from thinc.search cimport Beam
from .._ml cimport Model from .._ml cimport Model
from .arc_eager cimport TransitionSystem from .arc_eager cimport TransitionSystem
from ..tokens cimport Tokens, TokenC from ..tokens cimport Tokens, TokenC
from ._state cimport State
cdef class GreedyParser: cdef class GreedyParser:
cdef readonly object cfg cdef readonly object cfg
cdef readonly Model model cdef readonly Model model
cdef readonly TransitionSystem moves cdef readonly TransitionSystem moves
cdef State* _greedy_parse(self, Tokens tokens) except NULL
cdef State* _beam_parse(self, Tokens tokens) except NULL

View File

@ -23,13 +23,16 @@ from thinc.features cimport count_feats
from thinc.learner cimport LinearModel from thinc.learner cimport LinearModel
from thinc.search cimport Beam
from thinc.search cimport MaxViolation
from ..tokens cimport Tokens, TokenC from ..tokens cimport Tokens, TokenC
from ..strings cimport StringStore from ..strings cimport StringStore
from .arc_eager cimport TransitionSystem, Transition from .arc_eager cimport TransitionSystem, Transition
from .transition_system import OracleError from .transition_system import OracleError
from ._state cimport new_state, State, is_final, get_idx, get_s0, get_s1, get_n0, get_n1 from ._state cimport State, new_state, copy_state, is_final, push_stack
from ..gold cimport GoldParse from ..gold cimport GoldParse
from . import _parse_features from . import _parse_features
@ -67,7 +70,7 @@ def get_templates(name):
pf.tree_shape + pf.trigrams) pf.tree_shape + pf.trigrams)
cdef class GreedyParser: cdef class Parser:
def __init__(self, StringStore strings, model_dir, transition_system): def __init__(self, StringStore strings, model_dir, transition_system):
assert os.path.exists(model_dir) and os.path.isdir(model_dir) assert os.path.exists(model_dir) and os.path.isdir(model_dir)
self.cfg = Config.read(model_dir, 'config') self.cfg = Config.read(model_dir, 'config')
@ -78,7 +81,15 @@ cdef class GreedyParser:
def __call__(self, Tokens tokens): def __call__(self, Tokens tokens):
if tokens.length == 0: if tokens.length == 0:
return 0 return 0
cdef State* state
if self.cfg.beam_width == 1:
state = self._greedy_parse(tokens)
else:
state = self._beam_parse(tokens)
self.moves.finalize_state(state)
tokens.set_parse(state.sent)
cdef State* _greedy_parse(self, Tokens tokens) except NULL:
cdef atom_t[CONTEXT_SIZE] context cdef atom_t[CONTEXT_SIZE] context
cdef int n_feats cdef int n_feats
cdef Pool mem = Pool() cdef Pool mem = Pool()
@ -87,16 +98,26 @@ cdef class GreedyParser:
cdef Transition guess cdef Transition guess
while not is_final(state): while not is_final(state):
fill_context(context, state) fill_context(context, state)
scores = self.model.score(context, False) scores = self.model.score(context)
guess = self.moves.best_valid(scores, state) guess = self.moves.best_valid(scores, state)
guess.do(&guess, state) guess.do(&guess, state)
self.moves.finalize_state(state) return state
tokens.set_parse(state.sent)
return 0 cdef State* _beam_parse(self, Tokens tokens) except NULL:
cdef Beam beam = Beam(self.model.n_classes, self.cfg.beam_width)
beam.initialize(_init_state, tokens.length, tokens.data)
while not beam.is_done:
self._advance_beam(beam, None, False)
return <State*>beam.at(0)
def train(self, Tokens tokens, GoldParse gold): def train(self, Tokens tokens, GoldParse gold):
py_words = [w.orth_ for w in tokens]
self.moves.preprocess_gold(gold) self.moves.preprocess_gold(gold)
if self.beam_width == 1:
return self._greedy_train(tokens, gold)
else:
return self._beam_train(tokens, gold)
def _greedy_train(self, Tokens tokens, GoldParse gold):
cdef Pool mem = Pool() cdef Pool mem = Pool()
cdef State* state = new_state(mem, tokens.data, tokens.length) cdef State* state = new_state(mem, tokens.data, tokens.length)
self.moves.initialize_state(state) self.moves.initialize_state(state)
@ -109,16 +130,93 @@ cdef class GreedyParser:
cdef atom_t[CONTEXT_SIZE] context cdef atom_t[CONTEXT_SIZE] context
loss = 0 loss = 0
while not is_final(state): while not is_final(state):
fill_context(context, state) fill_context(context, state)
scores = self.model.score(context, True) scores = self.model.score(context)
guess = self.moves.best_valid(scores, state) guess = self.moves.best_valid(scores, state)
best = self.moves.best_gold(scores, state, gold) best = self.moves.best_gold(scores, state, gold)
cost = guess.get_cost(&guess, state, gold) cost = guess.get_cost(&guess, state, gold)
self.model.update(context, guess.clas, best.clas, cost) self.model.update(context, guess.clas, best.clas, cost)
guess.do(&guess, state) guess.do(&guess, state)
loss += cost loss += cost
self.moves.finalize_state(state)
return loss return loss
def _beam_train(self, Tokens tokens, GoldParse gold_parse):
cdef Beam pred = Beam(self.model.n_classes, self.cfg.beam_width)
pred.initialize(_init_state, tokens.length, tokens.data)
cdef Beam gold = Beam(self.model.n_classes, self.cfg.beam_width)
gold.initialize(_init_state, tokens.length, tokens.data)
violn = MaxViolation()
while not pred.is_done and not gold.is_done:
self._advance_beam(pred, gold_parse, False)
self._advance_beam(gold, gold_parse, True)
violn.check(pred, gold)
counts = {}
if pred.loss >= 1:
self._count_feats(counts, tokens, violn.g_hist, 1)
self._count_feats(counts, tokens, violn.p_hist, -1)
self.model._model.update(counts)
return pred.loss
def _advance_beam(self, Beam beam, GoldParse gold, bint follow_gold):
cdef atom_t[CONTEXT_SIZE] context
cdef State* state
cdef int i, j, cost
cdef bint is_valid
cdef const Transition* move
for i in range(beam.size):
state = <State*>beam.at(i)
fill_context(context, state)
scores = self.model.score(context)
validities = self.moves.get_valid(state)
if gold is None:
for j in range(self.model.n_clases):
beam.set_cell(i, j, scores[j], 0, validities[j])
elif not follow_gold:
for j in range(self.model.n_classes):
move = &self.moves.c[j]
cost = move.get_cost(move, state, gold)
beam.set_cell(i, j, scores[j], cost, validities[j])
else:
for j in range(self.model.n_classes):
move = &self.moves.c[j]
cost = move.get_cost(move, state, gold)
beam.set_cell(i, j, scores[j], cost, cost == 0)
beam.advance(_transition_state, <void*>self.moves.c)
beam.check_done(_check_final_state, NULL)
def _count_feats(self, dict counts, Tokens tokens, list hist, int inc):
cdef atom_t[CONTEXT_SIZE] context
cdef Pool mem = Pool()
cdef State* state = new_state(mem, tokens.data, tokens.length)
self.moves.initialize_state(state)
cdef class_t clas
cdef int n_feats
for clas in hist:
if is_final(state):
break
fill_context(context, state)
feats = self.model._extractor.get_feats(context, &n_feats)
count_feats(counts.setdefault(clas, {}), feats, n_feats, inc)
self.moves.c[clas].do(&self.moves.c[clas], state)
# These are passed as callbacks to thinc.search.Beam
cdef int _transition_state(void* _dest, void* _src, class_t clas, void* _moves) except -1:
dest = <State*>_dest
src = <const State*>_src
moves = <const Transition*>_moves
copy_state(dest, src)
moves[clas].do(&moves[clas], dest)
cdef void* _init_state(Pool mem, int length, void* tokens) except NULL:
state = new_state(mem, <const TokenC*>tokens, length)
push_stack(state)
return state
cdef int _check_final_state(void* state, void* extra_args) except -1:
return is_final(<State*>state)