* Add a parse_batch method to Parser, that releases the GIL around a batch of documents.

This commit is contained in:
Matthew Honnibal 2016-02-01 08:34:55 +01:00
parent 80caba28c7
commit bcf8f7ba40
7 changed files with 89 additions and 48 deletions

View File

@ -1,9 +1,10 @@
from thinc.typedefs cimport atom_t from thinc.typedefs cimport atom_t
from .stateclass cimport StateClass from .stateclass cimport StateClass
from ._state cimport StateC
cdef int fill_context(atom_t* context, StateClass state) nogil cdef int fill_context(atom_t* context, const StateC* state) nogil
# Context elements # Context elements
# Ensure each token's attributes are listed: w, p, c, c6, c4. The order # Ensure each token's attributes are listed: w, p, c, c6, c4. The order

View File

@ -14,6 +14,7 @@ from itertools import combinations
from ..structs cimport TokenC from ..structs cimport TokenC
from .stateclass cimport StateClass from .stateclass cimport StateClass
from ._state cimport StateC
from cymem.cymem cimport Pool from cymem.cymem cimport Pool
@ -59,7 +60,7 @@ cdef inline void fill_token(atom_t* context, const TokenC* token) nogil:
context[10] = token.ent_iob context[10] = token.ent_iob
context[11] = token.ent_type context[11] = token.ent_type
cdef int fill_context(atom_t* ctxt, StateClass st) nogil: cdef int fill_context(atom_t* ctxt, const StateC* st) nogil:
# Take care to fill every element of context! # Take care to fill every element of context!
# We could memset, but this makes it very easy to have broken features that # We could memset, but this makes it very easy to have broken features that
# make almost no impact on accuracy. If instead they're unset, the impact # make almost no impact on accuracy. If instead they're unset, the impact

View File

@ -377,23 +377,23 @@ cdef class ArcEager(TransitionSystem):
raise Exception(move) raise Exception(move)
return t return t
cdef int initialize_state(self, StateClass st) except -1: cdef int initialize_state(self, StateC* st) nogil:
# Ensure sent_start is set to 0 throughout # Ensure sent_start is set to 0 throughout
for i in range(st.c.length): for i in range(st.length):
st.c._sent[i].sent_start = False st._sent[i].sent_start = False
st.c._sent[i].l_edge = i st._sent[i].l_edge = i
st.c._sent[i].r_edge = i st._sent[i].r_edge = i
st.fast_forward() st.fast_forward()
cdef int finalize_state(self, StateClass st) nogil: cdef int finalize_state(self, StateC* st) nogil:
cdef int i cdef int i
for i in range(st.c.length): for i in range(st.length):
if st.c._sent[i].head == 0 and st.c._sent[i].dep == 0: if st._sent[i].head == 0 and st._sent[i].dep == 0:
st.c._sent[i].dep = self.root_label st._sent[i].dep = self.root_label
# If we're not using the Break transition, we segment via root-labelled # If we're not using the Break transition, we segment via root-labelled
# arcs between the root words. # arcs between the root words.
elif USE_ROOT_ARC_SEGMENT and st.c._sent[i].dep == self.root_label: elif USE_ROOT_ARC_SEGMENT and st._sent[i].dep == self.root_label:
st.c._sent[i].head = 0 st._sent[i].head = 0
cdef int set_valid(self, int* output, const StateC* st) nogil: cdef int set_valid(self, int* output, const StateC* st) nogil:
cdef bint[N_MOVES] is_valid cdef bint[N_MOVES] is_valid

View File

@ -6,14 +6,15 @@ from .stateclass cimport StateClass
from .arc_eager cimport TransitionSystem from .arc_eager cimport TransitionSystem
from ..tokens.doc cimport Doc from ..tokens.doc cimport Doc
from ..structs cimport TokenC from ..structs cimport TokenC
from ._state cimport StateC
cdef class ParserModel(AveragedPerceptron): cdef class ParserModel(AveragedPerceptron):
cdef void set_featuresC(self, ExampleC* eg, StateClass stcls) nogil cdef void set_featuresC(self, ExampleC* eg, const StateC* state) nogil
cdef class Parser: cdef class Parser:
cdef readonly ParserModel model cdef readonly ParserModel model
cdef readonly TransitionSystem moves cdef readonly TransitionSystem moves
cdef void parseC(self, Doc tokens, StateClass stcls, Example eg) nogil cdef void parseC(self, TokenC* tokens, int length, int nr_feat, int nr_class) nogil

View File

@ -1,3 +1,4 @@
# cython: infer_types=True
""" """
MALT-style dependency parser MALT-style dependency parser
""" """
@ -9,6 +10,7 @@ from cpython.exc cimport PyErr_CheckSignals
from libc.stdint cimport uint32_t, uint64_t from libc.stdint cimport uint32_t, uint64_t
from libc.string cimport memset, memcpy from libc.string cimport memset, memcpy
from libc.stdlib cimport malloc, calloc, free
import random import random
import os.path import os.path
from os import path from os import path
@ -21,6 +23,7 @@ from murmurhash.mrmr cimport hash64
from thinc.typedefs cimport weight_t, class_t, feat_t, atom_t, hash_t from thinc.typedefs cimport weight_t, class_t, feat_t, atom_t, hash_t
from thinc.linear.avgtron cimport AveragedPerceptron from thinc.linear.avgtron cimport AveragedPerceptron
from thinc.linalg cimport VecVec from thinc.linalg cimport VecVec
from thinc.structs cimport FeatureC
from util import Config from util import Config
@ -38,6 +41,7 @@ from . import _parse_features
from ._parse_features cimport CONTEXT_SIZE from ._parse_features cimport CONTEXT_SIZE
from ._parse_features cimport fill_context from ._parse_features cimport fill_context
from .stateclass cimport StateClass from .stateclass cimport StateClass
from ._state cimport StateC
@ -65,8 +69,8 @@ def ParserFactory(transition_system):
cdef class ParserModel(AveragedPerceptron): cdef class ParserModel(AveragedPerceptron):
cdef void set_featuresC(self, ExampleC* eg, StateClass stcls) nogil: cdef void set_featuresC(self, ExampleC* eg, const StateC* state) nogil:
fill_context(eg.atoms, stcls) fill_context(eg.atoms, state)
eg.nr_feat = self.extracter.set_features(eg.features, eg.atoms) eg.nr_feat = self.extracter.set_features(eg.features, eg.atoms)
@ -99,43 +103,77 @@ cdef class Parser:
return (Parser, (self.moves.strings, self.moves, self.model), None, None) return (Parser, (self.moves.strings, self.moves, self.model), None, None)
def __call__(self, Doc tokens): def __call__(self, Doc tokens):
cdef StateClass stcls = StateClass.init(tokens.c, tokens.length) cdef int nr_class = self.moves.n_moves
self.moves.initialize_state(stcls) cdef int nr_feat = self.model.nr_feat
cdef Example eg = Example(
nr_class=self.moves.n_moves,
nr_atom=CONTEXT_SIZE,
nr_feat=self.model.nr_feat)
with nogil: with nogil:
self.parseC(tokens, stcls, eg) self.parseC(tokens.c, tokens.length, nr_feat, nr_class)
tokens.is_parsed = True
# Check for KeyboardInterrupt etc. Untested # Check for KeyboardInterrupt etc. Untested
PyErr_CheckSignals() PyErr_CheckSignals()
cdef void parseC(self, Doc tokens, StateClass stcls, Example eg) nogil: def parse_batch(self, batch):
while not stcls.is_final(): cdef TokenC** doc_ptr = <TokenC**>calloc(len(batch), sizeof(TokenC*))
self.model.set_featuresC(&eg.c, stcls) cdef int* lengths = <int*>calloc(len(batch), sizeof(int))
self.moves.set_valid(eg.c.is_valid, stcls.c) cdef Doc doc
self.model.set_scoresC(eg.c.scores, eg.c.features, eg.c.nr_feat) cdef int i
for i, doc in enumerate(batch):
doc_ptr[i] = doc.c
lengths[i] = doc.length
cdef int nr_class = self.moves.n_moves
cdef int nr_feat = self.model.nr_feat
cdef int nr_doc = len(batch)
with nogil:
for i in range(nr_doc):
self.parseC(doc_ptr[i], lengths[i], nr_feat, nr_class)
for doc in batch:
doc.is_parsed = True
# Check for KeyboardInterrupt etc. Untested
PyErr_CheckSignals()
free(doc_ptr)
free(lengths)
guess = VecVec.arg_max_if_true(eg.c.scores, eg.c.is_valid, eg.c.nr_class) cdef void parseC(self, TokenC* tokens, int length, int nr_feat, int nr_class) nogil:
cdef ExampleC eg
eg.nr_feat = nr_feat
eg.nr_atom = CONTEXT_SIZE
eg.nr_class = nr_class
eg.features = <FeatureC*>calloc(sizeof(FeatureC), nr_feat)
eg.atoms = <atom_t*>calloc(sizeof(atom_t), CONTEXT_SIZE)
eg.scores = <weight_t*>calloc(sizeof(weight_t), nr_class)
eg.is_valid = <int*>calloc(sizeof(int), nr_class)
state = new StateC(tokens, length)
self.moves.initialize_state(state)
cdef int i
while not state.is_final():
self.model.set_featuresC(&eg, state)
self.moves.set_valid(eg.is_valid, state)
self.model.set_scoresC(eg.scores, eg.features, eg.nr_feat)
guess = VecVec.arg_max_if_true(eg.scores, eg.is_valid, eg.nr_class)
action = self.moves.c[guess] action = self.moves.c[guess]
if not eg.c.is_valid[guess]: if not eg.is_valid[guess]:
with gil: with gil:
move_name = self.moves.move_name(action.move, action.label) move_name = self.moves.move_name(action.move, action.label)
raise ValueError("Illegal action: %s" % move_name) raise ValueError("Illegal action: %s" % move_name)
action.do(stcls.c, action.label) action.do(state, action.label)
memset(eg.c.scores, 0, sizeof(eg.c.scores[0]) * eg.c.nr_class) memset(eg.scores, 0, sizeof(eg.scores[0]) * eg.nr_class)
memset(eg.c.costs, 0, sizeof(eg.c.costs[0]) * eg.c.nr_class) memset(eg.costs, 0, sizeof(eg.costs[0]) * eg.nr_class)
for i in range(eg.c.nr_class): for i in range(eg.nr_class):
eg.c.is_valid[i] = 1 eg.is_valid[i] = 1
self.moves.finalize_state(stcls) self.moves.finalize_state(state)
tokens.set_parse(stcls.c._sent) for i in range(length):
tokens[i] = state._sent[i]
del state
free(eg.features)
free(eg.atoms)
free(eg.scores)
free(eg.is_valid)
def train(self, Doc tokens, GoldParse gold): def train(self, Doc tokens, GoldParse gold):
self.moves.preprocess_gold(gold) self.moves.preprocess_gold(gold)
cdef StateClass stcls = StateClass.init(tokens.c, tokens.length) cdef StateClass stcls = StateClass.init(tokens.c, tokens.length)
self.moves.initialize_state(stcls) self.moves.initialize_state(stcls.c)
cdef Pool mem = Pool() cdef Pool mem = Pool()
cdef Example eg = Example( cdef Example eg = Example(
nr_class=self.moves.n_moves, nr_class=self.moves.n_moves,
@ -144,7 +182,7 @@ cdef class Parser:
cdef weight_t loss = 0 cdef weight_t loss = 0
cdef Transition action cdef Transition action
while not stcls.is_final(): while not stcls.is_final():
self.model.set_featuresC(&eg.c, stcls) self.model.set_featuresC(&eg.c, stcls.c)
self.moves.set_costs(eg.c.is_valid, eg.c.costs, stcls, gold) self.moves.set_costs(eg.c.is_valid, eg.c.costs, stcls, gold)
self.model.set_scoresC(eg.c.scores, eg.c.features, eg.c.nr_feat) self.model.set_scoresC(eg.c.scores, eg.c.features, eg.c.nr_feat)
self.model.updateC(&eg.c) self.model.updateC(&eg.c)
@ -174,7 +212,7 @@ cdef class StepwiseState:
self.parser = parser self.parser = parser
self.doc = doc self.doc = doc
self.stcls = StateClass.init(doc.c, doc.length) self.stcls = StateClass.init(doc.c, doc.length)
self.parser.moves.initialize_state(self.stcls) self.parser.moves.initialize_state(self.stcls.c)
self.eg = Example( self.eg = Example(
nr_class=self.parser.moves.n_moves, nr_class=self.parser.moves.n_moves,
nr_atom=CONTEXT_SIZE, nr_atom=CONTEXT_SIZE,
@ -209,7 +247,7 @@ cdef class StepwiseState:
def predict(self): def predict(self):
self.eg.reset() self.eg.reset()
self.parser.model.set_featuresC(&self.eg.c, self.stcls) self.parser.model.set_featuresC(&self.eg.c, self.stcls.c)
self.parser.moves.set_valid(self.eg.c.is_valid, self.stcls.c) self.parser.moves.set_valid(self.eg.c.is_valid, self.stcls.c)
self.parser.model.set_scoresC(self.eg.c.scores, self.parser.model.set_scoresC(self.eg.c.scores,
self.eg.c.features, self.eg.c.nr_feat) self.eg.c.features, self.eg.c.nr_feat)
@ -234,7 +272,7 @@ cdef class StepwiseState:
def finish(self): def finish(self):
if self.stcls.is_final(): if self.stcls.is_final():
self.parser.moves.finalize_state(self.stcls) self.parser.moves.finalize_state(self.stcls.c)
self.doc.set_parse(self.stcls.c._sent) self.doc.set_parse(self.stcls.c._sent)

View File

@ -38,8 +38,8 @@ cdef class TransitionSystem:
cdef public int root_label cdef public int root_label
cdef public freqs cdef public freqs
cdef int initialize_state(self, StateClass state) except -1 cdef int initialize_state(self, StateC* state) nogil
cdef int finalize_state(self, StateClass state) nogil cdef int finalize_state(self, StateC* state) nogil
cdef int preprocess_gold(self, GoldParse gold) except -1 cdef int preprocess_gold(self, GoldParse gold) except -1

View File

@ -47,10 +47,10 @@ cdef class TransitionSystem:
(self.strings, labels_by_action, self.freqs), (self.strings, labels_by_action, self.freqs),
None, None) None, None)
cdef int initialize_state(self, StateClass state) except -1: cdef int initialize_state(self, StateC* state) nogil:
pass pass
cdef int finalize_state(self, StateClass state) nogil: cdef int finalize_state(self, StateC* state) nogil:
pass pass
cdef int preprocess_gold(self, GoldParse gold) except -1: cdef int preprocess_gold(self, GoldParse gold) except -1: