Merge pull request #1438 from explosion/feature/fast-parser

💫 Improve runtime CPU efficiency of parser/NER
commit 61bc203f3f
Author: Matthew Honnibal
Date: 2017-10-19 02:42:21 +02:00 (committed by GitHub)
6 changed files with 123 additions and 190 deletions

View File

@@ -53,7 +53,8 @@ MOD_NAMES = [
 COMPILE_OPTIONS = {
     'msvc': ['/Ox', '/EHsc'],
     'mingw32' : ['-O3', '-Wno-strict-prototypes', '-Wno-unused-function'],
-    'other' : ['-O3', '-Wno-strict-prototypes', '-Wno-unused-function']
+    'other' : ['-O3', '-Wno-strict-prototypes', '-Wno-unused-function',
+               '-march=native']
 }
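Note: the new `-march=native` flag lets GCC/Clang emit vector instructions (e.g. AVX) tuned to the machine doing the compiling, which speeds up the parser's dense arithmetic but makes the built binary non-portable to older CPUs. A rough sketch of how such per-compiler options are typically applied in a build; the `build_ext_options` subclass below is illustrative, not necessarily spaCy's exact setup.py code:

    from distutils.command.build_ext import build_ext

    COMPILE_OPTIONS = {
        'msvc': ['/Ox', '/EHsc'],
        'mingw32': ['-O3', '-Wno-strict-prototypes', '-Wno-unused-function'],
        'other': ['-O3', '-Wno-strict-prototypes', '-Wno-unused-function',
                  '-march=native'],
    }

    class build_ext_options(build_ext):
        def build_extensions(self):
            # Look up flags for the active compiler, falling back to 'other'.
            options = COMPILE_OPTIONS.get(self.compiler.compiler_type,
                                          COMPILE_OPTIONS['other'])
            for extension in self.extensions:
                extension.extra_compile_args = options
            build_ext.build_extensions(self)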

View File

@@ -10,6 +10,7 @@ from collections import OrderedDict
 import itertools
 import weakref
 import functools
+import tqdm
 from .tokenizer import Tokenizer
 from .vocab import Vocab
@@ -447,11 +448,9 @@ class Language(object):
         golds = list(golds)
         for name, pipe in self.pipeline:
             if not hasattr(pipe, 'pipe'):
-                for doc in docs:
-                    pipe(doc)
+                docs = (pipe(doc) for doc in docs)
             else:
-                docs = list(pipe.pipe(docs))
-                assert len(docs) == len(golds)
+                docs = pipe.pipe(docs, batch_size=256)
         for doc, gold in zip(docs, golds):
             if verbose:
                 print(doc)
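Note: the rewritten loop threads the docs through the pipeline as a lazy generator chain instead of materializing a list after every component, so memory stays flat and batch-capable components receive the whole stream via `pipe(docs, batch_size=256)`. A standalone sketch of the pattern with toy components (not spaCy's API):

    def tag(docs):
        # Stands in for a component with a batch-friendly .pipe() method.
        return (doc + '/TAGGED' for doc in docs)

    docs = iter(['one', 'two', 'three'])
    docs = (doc.upper() for doc in docs)  # per-doc component, wrapped lazily
    docs = tag(docs)                      # batch-capable component
    for doc in docs:                      # work only happens as docs are consumed
        print(doc)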

View File

@@ -15,8 +15,6 @@ cdef class Parser:
     cdef readonly object cfg
     cdef public object _multitasks
-    cdef void _parse_step(self, StateC* state,
-            const float* feat_weights,
-            int nr_class, int nr_feat, int nr_piece) nogil
-    #cdef int parseC(self, TokenC* tokens, int length, int nr_feat) nogil
+    cdef void _parseC(self, StateC* state,
+            const float* feat_weights, const float* hW, const float* hb,
+            int nr_class, int nr_hidden, int nr_feat, int nr_piece) nogil

View File

@@ -9,6 +9,7 @@ from collections import Counter, OrderedDict
 import ujson
 import json
 import contextlib
+import numpy
 from libc.math cimport exp
 cimport cython
@@ -27,7 +28,7 @@ from libc.string cimport memset, memcpy
 from libc.stdlib cimport malloc, calloc, free
 from thinc.typedefs cimport weight_t, class_t, feat_t, atom_t, hash_t
 from thinc.linear.avgtron cimport AveragedPerceptron
-from thinc.linalg cimport VecVec
+from thinc.linalg cimport Vec, VecVec
 from thinc.structs cimport SparseArrayC, FeatureC, ExampleC
 from thinc.extra.eg cimport Example
 from thinc.extra.search cimport Beam
@@ -37,7 +38,7 @@ from murmurhash.mrmr cimport hash64
 from preshed.maps cimport MapStruct
 from preshed.maps cimport map_get
-from thinc.api import layerize, chain, noop, clone, with_flatten
+from thinc.api import layerize, chain, clone, with_flatten
 from thinc.v2v import Model, Maxout, Softmax, Affine, ReLu, SELU
 from thinc.misc import LayerNorm
@ -240,54 +241,32 @@ cdef class Parser:
@classmethod @classmethod
def Model(cls, nr_class, **cfg): def Model(cls, nr_class, **cfg):
depth = util.env_opt('parser_hidden_depth', cfg.get('hidden_depth', 1)) depth = util.env_opt('parser_hidden_depth', cfg.get('hidden_depth', 1))
if depth != 1:
raise ValueError("Currently parser depth is hard-coded to 1.")
parser_maxout_pieces = util.env_opt('parser_maxout_pieces', cfg.get('maxout_pieces', 2))
if parser_maxout_pieces != 2:
raise ValueError("Currently parser_maxout_pieces is hard-coded to 2")
token_vector_width = util.env_opt('token_vector_width', cfg.get('token_vector_width', 128)) token_vector_width = util.env_opt('token_vector_width', cfg.get('token_vector_width', 128))
hidden_width = util.env_opt('hidden_width', cfg.get('hidden_width', 200)) hidden_width = util.env_opt('hidden_width', cfg.get('hidden_width', 200))
parser_maxout_pieces = util.env_opt('parser_maxout_pieces', cfg.get('maxout_pieces', 2))
embed_size = util.env_opt('embed_size', cfg.get('embed_size', 7000)) embed_size = util.env_opt('embed_size', cfg.get('embed_size', 7000))
hist_size = util.env_opt('history_feats', cfg.get('hist_size', 0)) hist_size = util.env_opt('history_feats', cfg.get('hist_size', 0))
hist_width = util.env_opt('history_width', cfg.get('hist_width', 0)) hist_width = util.env_opt('history_width', cfg.get('hist_width', 0))
if hist_size >= 1 and depth == 0: if hist_size != 0:
raise ValueError("Inconsistent hyper-params: " raise ValueError("Currently history size is hard-coded to 0")
"history_feats >= 1 but parser_hidden_depth==0") if hist_width != 0:
raise ValueError("Currently history width is hard-coded to 0")
tok2vec = Tok2Vec(token_vector_width, embed_size, tok2vec = Tok2Vec(token_vector_width, embed_size,
pretrained_dims=cfg.get('pretrained_dims', 0)) pretrained_dims=cfg.get('pretrained_dims', 0))
tok2vec = chain(tok2vec, flatten) tok2vec = chain(tok2vec, flatten)
if parser_maxout_pieces == 1: lower = PrecomputableMaxouts(hidden_width if depth >= 1 else nr_class,
lower = PrecomputableAffine(hidden_width if depth >= 1 else nr_class, nF=cls.nr_feature, nP=parser_maxout_pieces,
nF=cls.nr_feature, nI=token_vector_width)
nI=token_vector_width)
else:
lower = PrecomputableMaxouts(hidden_width if depth >= 1 else nr_class,
nF=cls.nr_feature,
nP=parser_maxout_pieces,
nI=token_vector_width)
with Model.use_device('cpu'): with Model.use_device('cpu'):
if depth == 0: upper = chain(
upper = chain() clone(LayerNorm(Maxout(hidden_width, hidden_width)), depth-1),
upper.is_noop = True zero_init(Affine(nr_class, hidden_width, drop_factor=0.0))
elif hist_size and depth == 1: )
upper = chain(
HistoryFeatures(nr_class=nr_class, hist_size=hist_size,
nr_dim=hist_width),
zero_init(Affine(nr_class, hidden_width+hist_size*hist_width,
drop_factor=0.0)))
upper.is_noop = False
elif hist_size:
upper = chain(
HistoryFeatures(nr_class=nr_class, hist_size=hist_size,
nr_dim=hist_width),
LayerNorm(Maxout(hidden_width, hidden_width+hist_size*hist_width)),
clone(LayerNorm(Maxout(hidden_width, hidden_width)), depth-2),
zero_init(Affine(nr_class, hidden_width, drop_factor=0.0))
)
upper.is_noop = False
else:
upper = chain(
clone(LayerNorm(Maxout(hidden_width, hidden_width)), depth-1),
zero_init(Affine(nr_class, hidden_width, drop_factor=0.0))
)
upper.is_noop = False
# TODO: This is an unfortunate hack atm! # TODO: This is an unfortunate hack atm!
# Used to set input dimensions in network. # Used to set input dimensions in network.
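Note: with depth, maxout pieces and history features pinned (1, 2 and 0 respectively), the scoring network is fixed: precomputed maxout features feed one hidden layer, followed by a single affine output layer. A hedged numpy sketch of the arithmetic the hand-written `_parseC` loop below performs per state (shapes inferred from the diff, sizes here arbitrary):

    import numpy as np

    nr_hidden, nr_piece, nr_class = 64, 2, 30
    V = np.random.rand(nr_hidden, nr_piece)   # summed precomputed features for one state
    hW = np.random.rand(nr_hidden, nr_class)  # output weights, stored transposed
    hb = np.random.rand(nr_class)             # output bias

    hidden = V.max(axis=1)      # maxout: keep the better of the 2 pieces
    scores = hidden @ hW + hb   # affine layer over the nr_class actions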
@@ -391,90 +370,100 @@ cdef class Parser:
         beam_density = self.cfg.get('beam_density', 0.0)
         cdef Doc doc
         cdef Beam beam
-        for docs in cytoolz.partition_all(batch_size, docs):
-            docs = list(docs)
-            if beam_width == 1:
-                parse_states = self.parse_batch(docs)
-                beams = []
-            else:
-                beams = self.beam_parse(docs,
-                            beam_width=beam_width, beam_density=beam_density)
-                parse_states = []
-                for beam in beams:
-                    parse_states.append(<StateClass>beam.at(0))
-            self.set_annotations(docs, parse_states)
-            yield from docs
+        for batch in cytoolz.partition_all(batch_size, docs):
+            batch = list(batch)
+            by_length = sorted(list(batch), key=lambda doc: len(doc))
+            for subbatch in cytoolz.partition_all(8, by_length):
+                subbatch = list(subbatch)
+                if beam_width == 1:
+                    parse_states = self.parse_batch(subbatch)
+                    beams = []
+                else:
+                    beams = self.beam_parse(subbatch,
+                                beam_width=beam_width, beam_density=beam_density)
+                    parse_states = []
+                    for beam in beams:
+                        parse_states.append(<StateClass>beam.at(0))
+                self.set_annotations(subbatch, parse_states)
+            yield from batch

     def parse_batch(self, docs):
         cdef:
             precompute_hiddens state2vec
-            StateClass state
+            StateClass stcls
             Pool mem
             const float* feat_weights
             StateC* st
-            vector[StateC*] next_step, this_step
-            int nr_class, nr_feat, nr_piece, nr_dim, nr_state
+            vector[StateC*] states
+            int guess, nr_class, nr_feat, nr_piece, nr_dim, nr_state, nr_step
+            int j
         if isinstance(docs, Doc):
             docs = [docs]

         cuda_stream = get_cuda_stream()
         (tokvecs, bp_tokvecs), state2vec, vec2scores = self.get_batch_model(docs, cuda_stream,
                                                                             0.0)
         nr_state = len(docs)
         nr_class = self.moves.n_moves
         nr_dim = tokvecs.shape[1]
         nr_feat = self.nr_feature
         nr_piece = state2vec.nP

-        states = self.moves.init_batch(docs)
-        for state in states:
-            if not state.c.is_final():
-                next_step.push_back(state.c)
+        state_objs = self.moves.init_batch(docs)
+        for stcls in state_objs:
+            if not stcls.c.is_final():
+                states.push_back(stcls.c)

         feat_weights = state2vec.get_feat_weights()
         cdef int i
-        cdef np.ndarray token_ids = numpy.zeros((nr_state, nr_feat), dtype='i')
-        cdef np.ndarray is_valid = numpy.zeros((nr_state, nr_class), dtype='i')
-        cdef np.ndarray scores
-        c_token_ids = <int*>token_ids.data
-        c_is_valid = <int*>is_valid.data
-        cdef int has_hidden = not getattr(vec2scores, 'is_noop', False)
-        cdef int nr_step
-        while not next_step.empty():
-            nr_step = next_step.size()
-            if not has_hidden:
-                for i in cython.parallel.prange(nr_step, num_threads=6,
-                                                nogil=True):
-                    self._parse_step(next_step[i],
-                        feat_weights, nr_class, nr_feat, nr_piece)
-            else:
-                hists = []
-                for i in range(nr_step):
-                    st = next_step[i]
-                    st.set_context_tokens(&c_token_ids[i*nr_feat], nr_feat)
-                    self.moves.set_valid(&c_is_valid[i*nr_class], st)
-                    hists.append([st.get_hist(j+1) for j in range(8)])
-                hists = numpy.asarray(hists)
-                vectors = state2vec(token_ids[:next_step.size()])
-                if self.cfg.get('hist_size'):
-                    scores = vec2scores((vectors, hists))
-                else:
-                    scores = vec2scores(vectors)
-                c_scores = <float*>scores.data
-                for i in range(nr_step):
-                    st = next_step[i]
-                    guess = arg_max_if_valid(
-                        &c_scores[i*nr_class], &c_is_valid[i*nr_class], nr_class)
-                    action = self.moves.c[guess]
-                    action.do(st, action.label)
-                    st.push_hist(guess)
-            this_step, next_step = next_step, this_step
-            next_step.clear()
-            for st in this_step:
-                if not st.is_final():
-                    next_step.push_back(st)
-        return states
+        cdef np.ndarray hidden_weights = numpy.ascontiguousarray(vec2scores._layers[-1].W.T)
+        cdef np.ndarray hidden_bias = vec2scores._layers[-1].b
+
+        hW = <float*>hidden_weights.data
+        hb = <float*>hidden_bias.data
+        cdef int nr_hidden = hidden_weights.shape[0]
+        cdef int nr_task = states.size()
+        with nogil:
+            for i in cython.parallel.prange(nr_task, num_threads=2,
+                                            schedule='guided'):
+                self._parseC(states[i],
+                    feat_weights, hW, hb,
+                    nr_class, nr_hidden, nr_feat, nr_piece)
+        return state_objs
+
+    cdef void _parseC(self, StateC* state,
+            const float* feat_weights, const float* hW, const float* hb,
+            int nr_class, int nr_hidden, int nr_feat, int nr_piece) nogil:
+        token_ids = <int*>calloc(nr_feat, sizeof(int))
+        is_valid = <int*>calloc(nr_class, sizeof(int))
+        vectors = <float*>calloc(nr_hidden * nr_piece, sizeof(float))
+        scores = <float*>calloc(nr_class, sizeof(float))
+
+        while not state.is_final():
+            state.set_context_tokens(token_ids, nr_feat)
+            memset(vectors, 0, nr_hidden * nr_piece * sizeof(float))
+            memset(scores, 0, nr_class * sizeof(float))
+            sum_state_features(vectors,
+                feat_weights, token_ids, 1, nr_feat, nr_hidden * nr_piece)
+            V = vectors
+            W = hW
+            for i in range(nr_hidden):
+                feature = V[0] if V[0] >= V[1] else V[1]
+                for j in range(nr_class):
+                    scores[j] += feature * W[j]
+                W += nr_class
+                V += nr_piece
+            for i in range(nr_class):
+                scores[i] += hb[i]
+            self.moves.set_valid(is_valid, state)
+            guess = arg_max_if_valid(scores, is_valid, nr_class)
+            action = self.moves.c[guess]
+            action.do(state, action.label)
+            state.push_hist(guess)
+        free(token_ids)
+        free(is_valid)
+        free(vectors)
+        free(scores)

     def beam_parse(self, docs, int beam_width=3, float beam_density=0.001):
         cdef Beam beam
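Note: two scheduling changes do most of the work in the hunk above. First, `pipe()` sorts each batch by document length and parses sub-batches of 8, so similarly-sized documents are processed together. Second, `parse_batch()` no longer steps all states in lock-step through the network; each document is instead parsed to completion by `_parseC` inside a `nogil` OpenMP task (`prange(..., num_threads=2, schedule='guided')`), with per-task `calloc`/`free` buffers so threads share nothing but the read-only weights. This also makes the old no-hidden-layer `_parse_step` fast path (removed below) unnecessary. A toy sketch of the length-sorted sub-batching, with smaller sizes for readability:

    import cytoolz

    docs = ['dddd', 'a', 'ccc', 'bb', 'eeeee']
    for batch in cytoolz.partition_all(4, docs):               # outer batches
        by_length = sorted(batch, key=len)                     # group similar lengths
        for subbatch in cytoolz.partition_all(2, by_length):   # 8 in the real code
            print(list(subbatch))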
@@ -527,27 +516,6 @@ cdef class Parser:
             beams.append(beam)
         return beams

-    cdef void _parse_step(self, StateC* state,
-            const float* feat_weights,
-            int nr_class, int nr_feat, int nr_piece) nogil:
-        '''This only works with no hidden layers -- fast but inaccurate'''
-        token_ids = <int*>calloc(nr_feat, sizeof(int))
-        scores = <float*>calloc(nr_class * nr_piece, sizeof(float))
-        is_valid = <int*>calloc(nr_class, sizeof(int))
-
-        state.set_context_tokens(token_ids, nr_feat)
-        sum_state_features(scores,
-            feat_weights, token_ids, 1, nr_feat, nr_class * nr_piece)
-        self.moves.set_valid(is_valid, state)
-        guess = arg_maxout_if_valid(scores, is_valid, nr_class, nr_piece)
-        action = self.moves.c[guess]
-        action.do(state, action.label)
-        state.push_hist(guess)
-
-        free(is_valid)
-        free(scores)
-        free(token_ids)
-
     def update(self, docs, golds, drop=0., sgd=None, losses=None):
         if not any(self.moves.has_gold(gold) for gold in golds):
             return None
@@ -800,20 +768,11 @@
         if self.model not in (True, False, None) and resized:
             # Weights are stored in (nr_out, nr_in) format, so we're basically
             # just adding rows here.
-            if self.model[-1].is_noop:
-                smaller = self.model[1]
-                dims = dict(self.model[1]._dims)
-                dims['nO'] = self.moves.n_moves
-                larger = self.model[1].__class__(**dims)
-                copy_array(larger.W[:, :smaller.nO], smaller.W)
-                copy_array(larger.b[:smaller.nO], smaller.b)
-                self.model = (self.model[0], larger, self.model[2])
-            else:
-                smaller = self.model[-1]._layers[-1]
-                larger = Affine(self.moves.n_moves, smaller.nI)
-                copy_array(larger.W[:smaller.nO], smaller.W)
-                copy_array(larger.b[:smaller.nO], smaller.b)
-                self.model[-1]._layers[-1] = larger
+            smaller = self.model[-1]._layers[-1]
+            larger = Affine(self.moves.n_moves, smaller.nI)
+            copy_array(larger.W[:smaller.nO], smaller.W)
+            copy_array(larger.b[:smaller.nO], smaller.b)
+            self.model[-1]._layers[-1] = larger

     def begin_training(self, gold_tuples, pipeline=None, **cfg):
         if 'model' in cfg:
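Note: the resize logic collapses to one path now that the model always ends in an affine output layer: when new labels are added, the output layer grows by rows and the old weights are copied across. A numpy stand-in for the pattern (slice assignment in place of thinc's `copy_array`; function name illustrative):

    import numpy as np

    def resize_output(W_small, b_small, nr_class_new):
        # Weights are stored (nr_out, nr_in), so growing the class set
        # just means adding zeroed rows and copying the old ones in.
        nr_out, nr_in = W_small.shape
        W_large = np.zeros((nr_class_new, nr_in), dtype=W_small.dtype)
        b_large = np.zeros(nr_class_new, dtype=b_small.dtype)
        W_large[:nr_out] = W_small
        b_large[:nr_out] = b_small
        return W_large, b_large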
@@ -969,31 +928,6 @@ cdef int arg_max_if_valid(const weight_t* scores, const int* is_valid, int n) nogil
     return best

-cdef int arg_maxout_if_valid(const weight_t* scores, const int* is_valid,
-                             int n, int nP) nogil:
-    cdef int best = -1
-    cdef float best_score = 0
-    for i in range(n):
-        if is_valid[i] >= 1:
-            for j in range(nP):
-                if best == -1 or scores[i*nP+j] > best_score:
-                    best = i
-                    best_score = scores[i*nP+j]
-    return best
-
-cdef int _arg_max_clas(const weight_t* scores, int move, const Transition* actions,
-                       int nr_class) except -1:
-    cdef weight_t score = 0
-    cdef int mode = -1
-    cdef int i
-    for i in range(nr_class):
-        if actions[i].move == move and (mode == -1 or scores[i] >= score):
-            mode = i
-            score = scores[i]
-    return mode

 # These are passed as callbacks to thinc.search.Beam
 cdef int _transition_state(void* _dest, void* _src, class_t clas, void* _moves) except -1:
     dest = <StateClass>_dest

View File

@@ -148,7 +148,8 @@ cdef class TransitionSystem:
     def add_action(self, int action, label_name):
        cdef attr_t label_id
-        if not isinstance(label_name, (int, long)):
+        if not isinstance(label_name, int) and \
+           not isinstance(label_name, long):
            label_id = self.strings.add(label_name)
        else:
            label_id = label_name

View File

@@ -315,30 +315,30 @@ p
         +cell Number of rows in embedding tables.
         +cell #[code 7500]

-    +row
-        +cell #[code parser_maxout_pieces]
-        +cell Number of pieces in the parser's and NER's first maxout layer.
-        +cell #[code 2]
+    //- +row
+    //-     +cell #[code parser_maxout_pieces]
+    //-     +cell Number of pieces in the parser's and NER's first maxout layer.
+    //-     +cell #[code 2]

-    +row
-        +cell #[code parser_hidden_depth]
-        +cell Number of hidden layers in the parser and NER.
-        +cell #[code 1]
+    //- +row
+    //-     +cell #[code parser_hidden_depth]
+    //-     +cell Number of hidden layers in the parser and NER.
+    //-     +cell #[code 1]

     +row
         +cell #[code hidden_width]
         +cell Size of the parser's and NER's hidden layers.
         +cell #[code 128]

-    +row
-        +cell #[code history_feats]
-        +cell Number of previous action ID features for parser and NER.
-        +cell #[code 128]
+    //- +row
+    //-     +cell #[code history_feats]
+    //-     +cell Number of previous action ID features for parser and NER.
+    //-     +cell #[code 128]

-    +row
-        +cell #[code history_width]
-        +cell Number of embedding dimensions for each action ID.
-        +cell #[code 128]
+    //- +row
+    //-     +cell #[code history_width]
+    //-     +cell Number of embedding dimensions for each action ID.
+    //-     +cell #[code 128]

     +row
         +cell #[code learn_rate]