Merge pull request #1438 from explosion/feature/fast-parser

💫 Improve runtime CPU efficiency of parser/NER
This commit is contained in:
Matthew Honnibal 2017-10-19 02:42:21 +02:00 committed by GitHub
commit 61bc203f3f
6 changed files with 123 additions and 190 deletions

View File

@ -53,7 +53,8 @@ MOD_NAMES = [
COMPILE_OPTIONS = {
'msvc': ['/Ox', '/EHsc'],
'mingw32' : ['-O3', '-Wno-strict-prototypes', '-Wno-unused-function'],
'other' : ['-O3', '-Wno-strict-prototypes', '-Wno-unused-function']
'other' : ['-O3', '-Wno-strict-prototypes', '-Wno-unused-function',
'-march=native']
}

View File

@ -10,6 +10,7 @@ from collections import OrderedDict
import itertools
import weakref
import functools
import tqdm
from .tokenizer import Tokenizer
from .vocab import Vocab
@ -447,11 +448,9 @@ class Language(object):
golds = list(golds)
for name, pipe in self.pipeline:
if not hasattr(pipe, 'pipe'):
for doc in docs:
pipe(doc)
docs = (pipe(doc) for doc in docs)
else:
docs = list(pipe.pipe(docs))
assert len(docs) == len(golds)
docs = pipe.pipe(docs, batch_size=256)
for doc, gold in zip(docs, golds):
if verbose:
print(doc)

View File

@ -15,8 +15,6 @@ cdef class Parser:
cdef readonly object cfg
cdef public object _multitasks
cdef void _parse_step(self, StateC* state,
const float* feat_weights,
int nr_class, int nr_feat, int nr_piece) nogil
#cdef int parseC(self, TokenC* tokens, int length, int nr_feat) nogil
cdef void _parseC(self, StateC* state,
const float* feat_weights, const float* hW, const float* hb,
int nr_class, int nr_hidden, int nr_feat, int nr_piece) nogil

View File

@ -9,6 +9,7 @@ from collections import Counter, OrderedDict
import ujson
import json
import contextlib
import numpy
from libc.math cimport exp
cimport cython
@ -27,7 +28,7 @@ from libc.string cimport memset, memcpy
from libc.stdlib cimport malloc, calloc, free
from thinc.typedefs cimport weight_t, class_t, feat_t, atom_t, hash_t
from thinc.linear.avgtron cimport AveragedPerceptron
from thinc.linalg cimport VecVec
from thinc.linalg cimport Vec, VecVec
from thinc.structs cimport SparseArrayC, FeatureC, ExampleC
from thinc.extra.eg cimport Example
from thinc.extra.search cimport Beam
@ -37,7 +38,7 @@ from murmurhash.mrmr cimport hash64
from preshed.maps cimport MapStruct
from preshed.maps cimport map_get
from thinc.api import layerize, chain, noop, clone, with_flatten
from thinc.api import layerize, chain, clone, with_flatten
from thinc.v2v import Model, Maxout, Softmax, Affine, ReLu, SELU
from thinc.misc import LayerNorm
@ -240,54 +241,32 @@ cdef class Parser:
@classmethod
def Model(cls, nr_class, **cfg):
depth = util.env_opt('parser_hidden_depth', cfg.get('hidden_depth', 1))
if depth != 1:
raise ValueError("Currently parser depth is hard-coded to 1.")
parser_maxout_pieces = util.env_opt('parser_maxout_pieces', cfg.get('maxout_pieces', 2))
if parser_maxout_pieces != 2:
raise ValueError("Currently parser_maxout_pieces is hard-coded to 2")
token_vector_width = util.env_opt('token_vector_width', cfg.get('token_vector_width', 128))
hidden_width = util.env_opt('hidden_width', cfg.get('hidden_width', 200))
parser_maxout_pieces = util.env_opt('parser_maxout_pieces', cfg.get('maxout_pieces', 2))
embed_size = util.env_opt('embed_size', cfg.get('embed_size', 7000))
hist_size = util.env_opt('history_feats', cfg.get('hist_size', 0))
hist_width = util.env_opt('history_width', cfg.get('hist_width', 0))
if hist_size >= 1 and depth == 0:
raise ValueError("Inconsistent hyper-params: "
"history_feats >= 1 but parser_hidden_depth==0")
if hist_size != 0:
raise ValueError("Currently history size is hard-coded to 0")
if hist_width != 0:
raise ValueError("Currently history width is hard-coded to 0")
tok2vec = Tok2Vec(token_vector_width, embed_size,
pretrained_dims=cfg.get('pretrained_dims', 0))
tok2vec = chain(tok2vec, flatten)
if parser_maxout_pieces == 1:
lower = PrecomputableAffine(hidden_width if depth >= 1 else nr_class,
nF=cls.nr_feature,
nI=token_vector_width)
else:
lower = PrecomputableMaxouts(hidden_width if depth >= 1 else nr_class,
nF=cls.nr_feature,
nP=parser_maxout_pieces,
nF=cls.nr_feature, nP=parser_maxout_pieces,
nI=token_vector_width)
with Model.use_device('cpu'):
if depth == 0:
upper = chain()
upper.is_noop = True
elif hist_size and depth == 1:
upper = chain(
HistoryFeatures(nr_class=nr_class, hist_size=hist_size,
nr_dim=hist_width),
zero_init(Affine(nr_class, hidden_width+hist_size*hist_width,
drop_factor=0.0)))
upper.is_noop = False
elif hist_size:
upper = chain(
HistoryFeatures(nr_class=nr_class, hist_size=hist_size,
nr_dim=hist_width),
LayerNorm(Maxout(hidden_width, hidden_width+hist_size*hist_width)),
clone(LayerNorm(Maxout(hidden_width, hidden_width)), depth-2),
zero_init(Affine(nr_class, hidden_width, drop_factor=0.0))
)
upper.is_noop = False
else:
upper = chain(
clone(LayerNorm(Maxout(hidden_width, hidden_width)), depth-1),
zero_init(Affine(nr_class, hidden_width, drop_factor=0.0))
)
upper.is_noop = False
# TODO: This is an unfortunate hack atm!
# Used to set input dimensions in network.
@ -391,90 +370,100 @@ cdef class Parser:
beam_density = self.cfg.get('beam_density', 0.0)
cdef Doc doc
cdef Beam beam
for docs in cytoolz.partition_all(batch_size, docs):
docs = list(docs)
for batch in cytoolz.partition_all(batch_size, docs):
batch = list(batch)
by_length = sorted(list(batch), key=lambda doc: len(doc))
for subbatch in cytoolz.partition_all(8, by_length):
subbatch = list(subbatch)
if beam_width == 1:
parse_states = self.parse_batch(docs)
parse_states = self.parse_batch(subbatch)
beams = []
else:
beams = self.beam_parse(docs,
beams = self.beam_parse(subbatch,
beam_width=beam_width, beam_density=beam_density)
parse_states = []
for beam in beams:
parse_states.append(<StateClass>beam.at(0))
self.set_annotations(docs, parse_states)
yield from docs
self.set_annotations(subbatch, parse_states)
yield from batch
def parse_batch(self, docs):
cdef:
precompute_hiddens state2vec
StateClass state
StateClass stcls
Pool mem
const float* feat_weights
StateC* st
vector[StateC*] next_step, this_step
int nr_class, nr_feat, nr_piece, nr_dim, nr_state
vector[StateC*] states
int guess, nr_class, nr_feat, nr_piece, nr_dim, nr_state, nr_step
int j
if isinstance(docs, Doc):
docs = [docs]
cuda_stream = get_cuda_stream()
(tokvecs, bp_tokvecs), state2vec, vec2scores = self.get_batch_model(docs, cuda_stream,
0.0)
nr_state = len(docs)
nr_class = self.moves.n_moves
nr_dim = tokvecs.shape[1]
nr_feat = self.nr_feature
nr_piece = state2vec.nP
states = self.moves.init_batch(docs)
for state in states:
if not state.c.is_final():
next_step.push_back(state.c)
state_objs = self.moves.init_batch(docs)
for stcls in state_objs:
if not stcls.c.is_final():
states.push_back(stcls.c)
feat_weights = state2vec.get_feat_weights()
cdef int i
cdef np.ndarray token_ids = numpy.zeros((nr_state, nr_feat), dtype='i')
cdef np.ndarray is_valid = numpy.zeros((nr_state, nr_class), dtype='i')
cdef np.ndarray scores
c_token_ids = <int*>token_ids.data
c_is_valid = <int*>is_valid.data
cdef int has_hidden = not getattr(vec2scores, 'is_noop', False)
cdef int nr_step
while not next_step.empty():
nr_step = next_step.size()
if not has_hidden:
for i in cython.parallel.prange(nr_step, num_threads=6,
nogil=True):
self._parse_step(next_step[i],
feat_weights, nr_class, nr_feat, nr_piece)
else:
hists = []
for i in range(nr_step):
st = next_step[i]
st.set_context_tokens(&c_token_ids[i*nr_feat], nr_feat)
self.moves.set_valid(&c_is_valid[i*nr_class], st)
hists.append([st.get_hist(j+1) for j in range(8)])
hists = numpy.asarray(hists)
vectors = state2vec(token_ids[:next_step.size()])
if self.cfg.get('hist_size'):
scores = vec2scores((vectors, hists))
else:
scores = vec2scores(vectors)
c_scores = <float*>scores.data
for i in range(nr_step):
st = next_step[i]
guess = arg_max_if_valid(
&c_scores[i*nr_class], &c_is_valid[i*nr_class], nr_class)
cdef np.ndarray hidden_weights = numpy.ascontiguousarray(vec2scores._layers[-1].W.T)
cdef np.ndarray hidden_bias = vec2scores._layers[-1].b
hW = <float*>hidden_weights.data
hb = <float*>hidden_bias.data
cdef int nr_hidden = hidden_weights.shape[0]
cdef int nr_task = states.size()
with nogil:
for i in cython.parallel.prange(nr_task, num_threads=2,
schedule='guided'):
self._parseC(states[i],
feat_weights, hW, hb,
nr_class, nr_hidden, nr_feat, nr_piece)
return state_objs
cdef void _parseC(self, StateC* state,
const float* feat_weights, const float* hW, const float* hb,
int nr_class, int nr_hidden, int nr_feat, int nr_piece) nogil:
token_ids = <int*>calloc(nr_feat, sizeof(int))
is_valid = <int*>calloc(nr_class, sizeof(int))
vectors = <float*>calloc(nr_hidden * nr_piece, sizeof(float))
scores = <float*>calloc(nr_class, sizeof(float))
while not state.is_final():
state.set_context_tokens(token_ids, nr_feat)
memset(vectors, 0, nr_hidden * nr_piece * sizeof(float))
memset(scores, 0, nr_class * sizeof(float))
sum_state_features(vectors,
feat_weights, token_ids, 1, nr_feat, nr_hidden * nr_piece)
V = vectors
W = hW
for i in range(nr_hidden):
feature = V[0] if V[0] >= V[1] else V[1]
for j in range(nr_class):
scores[j] += feature * W[j]
W += nr_class
V += nr_piece
for i in range(nr_class):
scores[i] += hb[i]
self.moves.set_valid(is_valid, state)
guess = arg_max_if_valid(scores, is_valid, nr_class)
action = self.moves.c[guess]
action.do(st, action.label)
st.push_hist(guess)
this_step, next_step = next_step, this_step
next_step.clear()
for st in this_step:
if not st.is_final():
next_step.push_back(st)
return states
action.do(state, action.label)
state.push_hist(guess)
free(token_ids)
free(is_valid)
free(vectors)
free(scores)
def beam_parse(self, docs, int beam_width=3, float beam_density=0.001):
cdef Beam beam
@ -527,27 +516,6 @@ cdef class Parser:
beams.append(beam)
return beams
cdef void _parse_step(self, StateC* state,
const float* feat_weights,
int nr_class, int nr_feat, int nr_piece) nogil:
'''This only works with no hidden layers -- fast but inaccurate'''
token_ids = <int*>calloc(nr_feat, sizeof(int))
scores = <float*>calloc(nr_class * nr_piece, sizeof(float))
is_valid = <int*>calloc(nr_class, sizeof(int))
state.set_context_tokens(token_ids, nr_feat)
sum_state_features(scores,
feat_weights, token_ids, 1, nr_feat, nr_class * nr_piece)
self.moves.set_valid(is_valid, state)
guess = arg_maxout_if_valid(scores, is_valid, nr_class, nr_piece)
action = self.moves.c[guess]
action.do(state, action.label)
state.push_hist(guess)
free(is_valid)
free(scores)
free(token_ids)
def update(self, docs, golds, drop=0., sgd=None, losses=None):
if not any(self.moves.has_gold(gold) for gold in golds):
return None
@ -800,15 +768,6 @@ cdef class Parser:
if self.model not in (True, False, None) and resized:
# Weights are stored in (nr_out, nr_in) format, so we're basically
# just adding rows here.
if self.model[-1].is_noop:
smaller = self.model[1]
dims = dict(self.model[1]._dims)
dims['nO'] = self.moves.n_moves
larger = self.model[1].__class__(**dims)
copy_array(larger.W[:, :smaller.nO], smaller.W)
copy_array(larger.b[:smaller.nO], smaller.b)
self.model = (self.model[0], larger, self.model[2])
else:
smaller = self.model[-1]._layers[-1]
larger = Affine(self.moves.n_moves, smaller.nI)
copy_array(larger.W[:smaller.nO], smaller.W)
@ -969,31 +928,6 @@ cdef int arg_max_if_valid(const weight_t* scores, const int* is_valid, int n) no
return best
cdef int arg_maxout_if_valid(const weight_t* scores, const int* is_valid,
int n, int nP) nogil:
cdef int best = -1
cdef float best_score = 0
for i in range(n):
if is_valid[i] >= 1:
for j in range(nP):
if best == -1 or scores[i*nP+j] > best_score:
best = i
best_score = scores[i*nP+j]
return best
cdef int _arg_max_clas(const weight_t* scores, int move, const Transition* actions,
int nr_class) except -1:
cdef weight_t score = 0
cdef int mode = -1
cdef int i
for i in range(nr_class):
if actions[i].move == move and (mode == -1 or scores[i] >= score):
mode = i
score = scores[i]
return mode
# These are passed as callbacks to thinc.search.Beam
cdef int _transition_state(void* _dest, void* _src, class_t clas, void* _moves) except -1:
dest = <StateClass>_dest

View File

@ -148,7 +148,8 @@ cdef class TransitionSystem:
def add_action(self, int action, label_name):
cdef attr_t label_id
if not isinstance(label_name, (int, long)):
if not isinstance(label_name, int) and \
not isinstance(label_name, long):
label_id = self.strings.add(label_name)
else:
label_id = label_name

View File

@ -315,30 +315,30 @@ p
+cell Number of rows in embedding tables.
+cell #[code 7500]
+row
+cell #[code parser_maxout_pieces]
+cell Number of pieces in the parser's and NER's first maxout layer.
+cell #[code 2]
//- +row
//- +cell #[code parser_maxout_pieces]
//- +cell Number of pieces in the parser's and NER's first maxout layer.
//- +cell #[code 2]
+row
+cell #[code parser_hidden_depth]
+cell Number of hidden layers in the parser and NER.
+cell #[code 1]
//- +row
//- +cell #[code parser_hidden_depth]
//- +cell Number of hidden layers in the parser and NER.
//- +cell #[code 1]
+row
+cell #[code hidden_width]
+cell Size of the parser's and NER's hidden layers.
+cell #[code 128]
+row
+cell #[code history_feats]
+cell Number of previous action ID features for parser and NER.
+cell #[code 128]
//- +row
//- +cell #[code history_feats]
//- +cell Number of previous action ID features for parser and NER.
//- +cell #[code 128]
+row
+cell #[code history_width]
+cell Number of embedding dimensions for each action ID.
+cell #[code 128]
//- +row
//- +cell #[code history_width]
//- +cell Number of embedding dimensions for each action ID.
//- +cell #[code 128]
+row
+cell #[code learn_rate]