mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-12 09:12:21 +03:00
Merge pull request #1438 from explosion/feature/fast-parser
💫 Improve runtime CPU efficiency of parser/NER
This commit is contained in:
commit
61bc203f3f
3
setup.py
3
setup.py
|
@ -53,7 +53,8 @@ MOD_NAMES = [
|
||||||
COMPILE_OPTIONS = {
|
COMPILE_OPTIONS = {
|
||||||
'msvc': ['/Ox', '/EHsc'],
|
'msvc': ['/Ox', '/EHsc'],
|
||||||
'mingw32' : ['-O3', '-Wno-strict-prototypes', '-Wno-unused-function'],
|
'mingw32' : ['-O3', '-Wno-strict-prototypes', '-Wno-unused-function'],
|
||||||
'other' : ['-O3', '-Wno-strict-prototypes', '-Wno-unused-function']
|
'other' : ['-O3', '-Wno-strict-prototypes', '-Wno-unused-function',
|
||||||
|
'-march=native']
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -10,6 +10,7 @@ from collections import OrderedDict
|
||||||
import itertools
|
import itertools
|
||||||
import weakref
|
import weakref
|
||||||
import functools
|
import functools
|
||||||
|
import tqdm
|
||||||
|
|
||||||
from .tokenizer import Tokenizer
|
from .tokenizer import Tokenizer
|
||||||
from .vocab import Vocab
|
from .vocab import Vocab
|
||||||
|
@ -447,11 +448,9 @@ class Language(object):
|
||||||
golds = list(golds)
|
golds = list(golds)
|
||||||
for name, pipe in self.pipeline:
|
for name, pipe in self.pipeline:
|
||||||
if not hasattr(pipe, 'pipe'):
|
if not hasattr(pipe, 'pipe'):
|
||||||
for doc in docs:
|
docs = (pipe(doc) for doc in docs)
|
||||||
pipe(doc)
|
|
||||||
else:
|
else:
|
||||||
docs = list(pipe.pipe(docs))
|
docs = pipe.pipe(docs, batch_size=256)
|
||||||
assert len(docs) == len(golds)
|
|
||||||
for doc, gold in zip(docs, golds):
|
for doc, gold in zip(docs, golds):
|
||||||
if verbose:
|
if verbose:
|
||||||
print(doc)
|
print(doc)
|
||||||
|
|
|
@ -15,8 +15,6 @@ cdef class Parser:
|
||||||
cdef readonly object cfg
|
cdef readonly object cfg
|
||||||
cdef public object _multitasks
|
cdef public object _multitasks
|
||||||
|
|
||||||
cdef void _parse_step(self, StateC* state,
|
cdef void _parseC(self, StateC* state,
|
||||||
const float* feat_weights,
|
const float* feat_weights, const float* hW, const float* hb,
|
||||||
int nr_class, int nr_feat, int nr_piece) nogil
|
int nr_class, int nr_hidden, int nr_feat, int nr_piece) nogil
|
||||||
|
|
||||||
#cdef int parseC(self, TokenC* tokens, int length, int nr_feat) nogil
|
|
||||||
|
|
|
@ -9,6 +9,7 @@ from collections import Counter, OrderedDict
|
||||||
import ujson
|
import ujson
|
||||||
import json
|
import json
|
||||||
import contextlib
|
import contextlib
|
||||||
|
import numpy
|
||||||
|
|
||||||
from libc.math cimport exp
|
from libc.math cimport exp
|
||||||
cimport cython
|
cimport cython
|
||||||
|
@ -27,7 +28,7 @@ from libc.string cimport memset, memcpy
|
||||||
from libc.stdlib cimport malloc, calloc, free
|
from libc.stdlib cimport malloc, calloc, free
|
||||||
from thinc.typedefs cimport weight_t, class_t, feat_t, atom_t, hash_t
|
from thinc.typedefs cimport weight_t, class_t, feat_t, atom_t, hash_t
|
||||||
from thinc.linear.avgtron cimport AveragedPerceptron
|
from thinc.linear.avgtron cimport AveragedPerceptron
|
||||||
from thinc.linalg cimport VecVec
|
from thinc.linalg cimport Vec, VecVec
|
||||||
from thinc.structs cimport SparseArrayC, FeatureC, ExampleC
|
from thinc.structs cimport SparseArrayC, FeatureC, ExampleC
|
||||||
from thinc.extra.eg cimport Example
|
from thinc.extra.eg cimport Example
|
||||||
from thinc.extra.search cimport Beam
|
from thinc.extra.search cimport Beam
|
||||||
|
@ -37,7 +38,7 @@ from murmurhash.mrmr cimport hash64
|
||||||
from preshed.maps cimport MapStruct
|
from preshed.maps cimport MapStruct
|
||||||
from preshed.maps cimport map_get
|
from preshed.maps cimport map_get
|
||||||
|
|
||||||
from thinc.api import layerize, chain, noop, clone, with_flatten
|
from thinc.api import layerize, chain, clone, with_flatten
|
||||||
from thinc.v2v import Model, Maxout, Softmax, Affine, ReLu, SELU
|
from thinc.v2v import Model, Maxout, Softmax, Affine, ReLu, SELU
|
||||||
from thinc.misc import LayerNorm
|
from thinc.misc import LayerNorm
|
||||||
|
|
||||||
|
@ -240,54 +241,32 @@ cdef class Parser:
|
||||||
@classmethod
|
@classmethod
|
||||||
def Model(cls, nr_class, **cfg):
|
def Model(cls, nr_class, **cfg):
|
||||||
depth = util.env_opt('parser_hidden_depth', cfg.get('hidden_depth', 1))
|
depth = util.env_opt('parser_hidden_depth', cfg.get('hidden_depth', 1))
|
||||||
|
if depth != 1:
|
||||||
|
raise ValueError("Currently parser depth is hard-coded to 1.")
|
||||||
|
parser_maxout_pieces = util.env_opt('parser_maxout_pieces', cfg.get('maxout_pieces', 2))
|
||||||
|
if parser_maxout_pieces != 2:
|
||||||
|
raise ValueError("Currently parser_maxout_pieces is hard-coded to 2")
|
||||||
token_vector_width = util.env_opt('token_vector_width', cfg.get('token_vector_width', 128))
|
token_vector_width = util.env_opt('token_vector_width', cfg.get('token_vector_width', 128))
|
||||||
hidden_width = util.env_opt('hidden_width', cfg.get('hidden_width', 200))
|
hidden_width = util.env_opt('hidden_width', cfg.get('hidden_width', 200))
|
||||||
parser_maxout_pieces = util.env_opt('parser_maxout_pieces', cfg.get('maxout_pieces', 2))
|
|
||||||
embed_size = util.env_opt('embed_size', cfg.get('embed_size', 7000))
|
embed_size = util.env_opt('embed_size', cfg.get('embed_size', 7000))
|
||||||
hist_size = util.env_opt('history_feats', cfg.get('hist_size', 0))
|
hist_size = util.env_opt('history_feats', cfg.get('hist_size', 0))
|
||||||
hist_width = util.env_opt('history_width', cfg.get('hist_width', 0))
|
hist_width = util.env_opt('history_width', cfg.get('hist_width', 0))
|
||||||
if hist_size >= 1 and depth == 0:
|
if hist_size != 0:
|
||||||
raise ValueError("Inconsistent hyper-params: "
|
raise ValueError("Currently history size is hard-coded to 0")
|
||||||
"history_feats >= 1 but parser_hidden_depth==0")
|
if hist_width != 0:
|
||||||
|
raise ValueError("Currently history width is hard-coded to 0")
|
||||||
tok2vec = Tok2Vec(token_vector_width, embed_size,
|
tok2vec = Tok2Vec(token_vector_width, embed_size,
|
||||||
pretrained_dims=cfg.get('pretrained_dims', 0))
|
pretrained_dims=cfg.get('pretrained_dims', 0))
|
||||||
tok2vec = chain(tok2vec, flatten)
|
tok2vec = chain(tok2vec, flatten)
|
||||||
if parser_maxout_pieces == 1:
|
lower = PrecomputableMaxouts(hidden_width if depth >= 1 else nr_class,
|
||||||
lower = PrecomputableAffine(hidden_width if depth >= 1 else nr_class,
|
nF=cls.nr_feature, nP=parser_maxout_pieces,
|
||||||
nF=cls.nr_feature,
|
nI=token_vector_width)
|
||||||
nI=token_vector_width)
|
|
||||||
else:
|
|
||||||
lower = PrecomputableMaxouts(hidden_width if depth >= 1 else nr_class,
|
|
||||||
nF=cls.nr_feature,
|
|
||||||
nP=parser_maxout_pieces,
|
|
||||||
nI=token_vector_width)
|
|
||||||
|
|
||||||
with Model.use_device('cpu'):
|
with Model.use_device('cpu'):
|
||||||
if depth == 0:
|
upper = chain(
|
||||||
upper = chain()
|
clone(LayerNorm(Maxout(hidden_width, hidden_width)), depth-1),
|
||||||
upper.is_noop = True
|
zero_init(Affine(nr_class, hidden_width, drop_factor=0.0))
|
||||||
elif hist_size and depth == 1:
|
)
|
||||||
upper = chain(
|
|
||||||
HistoryFeatures(nr_class=nr_class, hist_size=hist_size,
|
|
||||||
nr_dim=hist_width),
|
|
||||||
zero_init(Affine(nr_class, hidden_width+hist_size*hist_width,
|
|
||||||
drop_factor=0.0)))
|
|
||||||
upper.is_noop = False
|
|
||||||
elif hist_size:
|
|
||||||
upper = chain(
|
|
||||||
HistoryFeatures(nr_class=nr_class, hist_size=hist_size,
|
|
||||||
nr_dim=hist_width),
|
|
||||||
LayerNorm(Maxout(hidden_width, hidden_width+hist_size*hist_width)),
|
|
||||||
clone(LayerNorm(Maxout(hidden_width, hidden_width)), depth-2),
|
|
||||||
zero_init(Affine(nr_class, hidden_width, drop_factor=0.0))
|
|
||||||
)
|
|
||||||
upper.is_noop = False
|
|
||||||
else:
|
|
||||||
upper = chain(
|
|
||||||
clone(LayerNorm(Maxout(hidden_width, hidden_width)), depth-1),
|
|
||||||
zero_init(Affine(nr_class, hidden_width, drop_factor=0.0))
|
|
||||||
)
|
|
||||||
upper.is_noop = False
|
|
||||||
|
|
||||||
# TODO: This is an unfortunate hack atm!
|
# TODO: This is an unfortunate hack atm!
|
||||||
# Used to set input dimensions in network.
|
# Used to set input dimensions in network.
|
||||||
|
@ -391,90 +370,100 @@ cdef class Parser:
|
||||||
beam_density = self.cfg.get('beam_density', 0.0)
|
beam_density = self.cfg.get('beam_density', 0.0)
|
||||||
cdef Doc doc
|
cdef Doc doc
|
||||||
cdef Beam beam
|
cdef Beam beam
|
||||||
for docs in cytoolz.partition_all(batch_size, docs):
|
for batch in cytoolz.partition_all(batch_size, docs):
|
||||||
docs = list(docs)
|
batch = list(batch)
|
||||||
if beam_width == 1:
|
by_length = sorted(list(batch), key=lambda doc: len(doc))
|
||||||
parse_states = self.parse_batch(docs)
|
for subbatch in cytoolz.partition_all(8, by_length):
|
||||||
beams = []
|
subbatch = list(subbatch)
|
||||||
else:
|
if beam_width == 1:
|
||||||
beams = self.beam_parse(docs,
|
parse_states = self.parse_batch(subbatch)
|
||||||
beam_width=beam_width, beam_density=beam_density)
|
beams = []
|
||||||
parse_states = []
|
else:
|
||||||
for beam in beams:
|
beams = self.beam_parse(subbatch,
|
||||||
parse_states.append(<StateClass>beam.at(0))
|
beam_width=beam_width, beam_density=beam_density)
|
||||||
self.set_annotations(docs, parse_states)
|
parse_states = []
|
||||||
yield from docs
|
for beam in beams:
|
||||||
|
parse_states.append(<StateClass>beam.at(0))
|
||||||
|
self.set_annotations(subbatch, parse_states)
|
||||||
|
yield from batch
|
||||||
|
|
||||||
def parse_batch(self, docs):
|
def parse_batch(self, docs):
|
||||||
cdef:
|
cdef:
|
||||||
precompute_hiddens state2vec
|
precompute_hiddens state2vec
|
||||||
StateClass state
|
StateClass stcls
|
||||||
Pool mem
|
Pool mem
|
||||||
const float* feat_weights
|
const float* feat_weights
|
||||||
StateC* st
|
StateC* st
|
||||||
vector[StateC*] next_step, this_step
|
vector[StateC*] states
|
||||||
int nr_class, nr_feat, nr_piece, nr_dim, nr_state
|
int guess, nr_class, nr_feat, nr_piece, nr_dim, nr_state, nr_step
|
||||||
|
int j
|
||||||
if isinstance(docs, Doc):
|
if isinstance(docs, Doc):
|
||||||
docs = [docs]
|
docs = [docs]
|
||||||
|
|
||||||
cuda_stream = get_cuda_stream()
|
cuda_stream = get_cuda_stream()
|
||||||
(tokvecs, bp_tokvecs), state2vec, vec2scores = self.get_batch_model(docs, cuda_stream,
|
(tokvecs, bp_tokvecs), state2vec, vec2scores = self.get_batch_model(docs, cuda_stream,
|
||||||
0.0)
|
0.0)
|
||||||
|
|
||||||
nr_state = len(docs)
|
nr_state = len(docs)
|
||||||
nr_class = self.moves.n_moves
|
nr_class = self.moves.n_moves
|
||||||
nr_dim = tokvecs.shape[1]
|
nr_dim = tokvecs.shape[1]
|
||||||
nr_feat = self.nr_feature
|
nr_feat = self.nr_feature
|
||||||
nr_piece = state2vec.nP
|
nr_piece = state2vec.nP
|
||||||
|
|
||||||
states = self.moves.init_batch(docs)
|
state_objs = self.moves.init_batch(docs)
|
||||||
for state in states:
|
for stcls in state_objs:
|
||||||
if not state.c.is_final():
|
if not stcls.c.is_final():
|
||||||
next_step.push_back(state.c)
|
states.push_back(stcls.c)
|
||||||
|
|
||||||
feat_weights = state2vec.get_feat_weights()
|
feat_weights = state2vec.get_feat_weights()
|
||||||
cdef int i
|
cdef int i
|
||||||
cdef np.ndarray token_ids = numpy.zeros((nr_state, nr_feat), dtype='i')
|
cdef np.ndarray hidden_weights = numpy.ascontiguousarray(vec2scores._layers[-1].W.T)
|
||||||
cdef np.ndarray is_valid = numpy.zeros((nr_state, nr_class), dtype='i')
|
cdef np.ndarray hidden_bias = vec2scores._layers[-1].b
|
||||||
cdef np.ndarray scores
|
|
||||||
c_token_ids = <int*>token_ids.data
|
hW = <float*>hidden_weights.data
|
||||||
c_is_valid = <int*>is_valid.data
|
hb = <float*>hidden_bias.data
|
||||||
cdef int has_hidden = not getattr(vec2scores, 'is_noop', False)
|
cdef int nr_hidden = hidden_weights.shape[0]
|
||||||
cdef int nr_step
|
cdef int nr_task = states.size()
|
||||||
while not next_step.empty():
|
with nogil:
|
||||||
nr_step = next_step.size()
|
for i in cython.parallel.prange(nr_task, num_threads=2,
|
||||||
if not has_hidden:
|
schedule='guided'):
|
||||||
for i in cython.parallel.prange(nr_step, num_threads=6,
|
self._parseC(states[i],
|
||||||
nogil=True):
|
feat_weights, hW, hb,
|
||||||
self._parse_step(next_step[i],
|
nr_class, nr_hidden, nr_feat, nr_piece)
|
||||||
feat_weights, nr_class, nr_feat, nr_piece)
|
return state_objs
|
||||||
else:
|
|
||||||
hists = []
|
cdef void _parseC(self, StateC* state,
|
||||||
for i in range(nr_step):
|
const float* feat_weights, const float* hW, const float* hb,
|
||||||
st = next_step[i]
|
int nr_class, int nr_hidden, int nr_feat, int nr_piece) nogil:
|
||||||
st.set_context_tokens(&c_token_ids[i*nr_feat], nr_feat)
|
token_ids = <int*>calloc(nr_feat, sizeof(int))
|
||||||
self.moves.set_valid(&c_is_valid[i*nr_class], st)
|
is_valid = <int*>calloc(nr_class, sizeof(int))
|
||||||
hists.append([st.get_hist(j+1) for j in range(8)])
|
vectors = <float*>calloc(nr_hidden * nr_piece, sizeof(float))
|
||||||
hists = numpy.asarray(hists)
|
scores = <float*>calloc(nr_class, sizeof(float))
|
||||||
vectors = state2vec(token_ids[:next_step.size()])
|
|
||||||
if self.cfg.get('hist_size'):
|
while not state.is_final():
|
||||||
scores = vec2scores((vectors, hists))
|
state.set_context_tokens(token_ids, nr_feat)
|
||||||
else:
|
memset(vectors, 0, nr_hidden * nr_piece * sizeof(float))
|
||||||
scores = vec2scores(vectors)
|
memset(scores, 0, nr_class * sizeof(float))
|
||||||
c_scores = <float*>scores.data
|
sum_state_features(vectors,
|
||||||
for i in range(nr_step):
|
feat_weights, token_ids, 1, nr_feat, nr_hidden * nr_piece)
|
||||||
st = next_step[i]
|
V = vectors
|
||||||
guess = arg_max_if_valid(
|
W = hW
|
||||||
&c_scores[i*nr_class], &c_is_valid[i*nr_class], nr_class)
|
for i in range(nr_hidden):
|
||||||
action = self.moves.c[guess]
|
feature = V[0] if V[0] >= V[1] else V[1]
|
||||||
action.do(st, action.label)
|
for j in range(nr_class):
|
||||||
st.push_hist(guess)
|
scores[j] += feature * W[j]
|
||||||
this_step, next_step = next_step, this_step
|
W += nr_class
|
||||||
next_step.clear()
|
V += nr_piece
|
||||||
for st in this_step:
|
for i in range(nr_class):
|
||||||
if not st.is_final():
|
scores[i] += hb[i]
|
||||||
next_step.push_back(st)
|
self.moves.set_valid(is_valid, state)
|
||||||
return states
|
guess = arg_max_if_valid(scores, is_valid, nr_class)
|
||||||
|
action = self.moves.c[guess]
|
||||||
|
action.do(state, action.label)
|
||||||
|
state.push_hist(guess)
|
||||||
|
free(token_ids)
|
||||||
|
free(is_valid)
|
||||||
|
free(vectors)
|
||||||
|
free(scores)
|
||||||
|
|
||||||
def beam_parse(self, docs, int beam_width=3, float beam_density=0.001):
|
def beam_parse(self, docs, int beam_width=3, float beam_density=0.001):
|
||||||
cdef Beam beam
|
cdef Beam beam
|
||||||
|
@ -527,27 +516,6 @@ cdef class Parser:
|
||||||
beams.append(beam)
|
beams.append(beam)
|
||||||
return beams
|
return beams
|
||||||
|
|
||||||
cdef void _parse_step(self, StateC* state,
|
|
||||||
const float* feat_weights,
|
|
||||||
int nr_class, int nr_feat, int nr_piece) nogil:
|
|
||||||
'''This only works with no hidden layers -- fast but inaccurate'''
|
|
||||||
token_ids = <int*>calloc(nr_feat, sizeof(int))
|
|
||||||
scores = <float*>calloc(nr_class * nr_piece, sizeof(float))
|
|
||||||
is_valid = <int*>calloc(nr_class, sizeof(int))
|
|
||||||
|
|
||||||
state.set_context_tokens(token_ids, nr_feat)
|
|
||||||
sum_state_features(scores,
|
|
||||||
feat_weights, token_ids, 1, nr_feat, nr_class * nr_piece)
|
|
||||||
self.moves.set_valid(is_valid, state)
|
|
||||||
guess = arg_maxout_if_valid(scores, is_valid, nr_class, nr_piece)
|
|
||||||
action = self.moves.c[guess]
|
|
||||||
action.do(state, action.label)
|
|
||||||
state.push_hist(guess)
|
|
||||||
|
|
||||||
free(is_valid)
|
|
||||||
free(scores)
|
|
||||||
free(token_ids)
|
|
||||||
|
|
||||||
def update(self, docs, golds, drop=0., sgd=None, losses=None):
|
def update(self, docs, golds, drop=0., sgd=None, losses=None):
|
||||||
if not any(self.moves.has_gold(gold) for gold in golds):
|
if not any(self.moves.has_gold(gold) for gold in golds):
|
||||||
return None
|
return None
|
||||||
|
@ -800,20 +768,11 @@ cdef class Parser:
|
||||||
if self.model not in (True, False, None) and resized:
|
if self.model not in (True, False, None) and resized:
|
||||||
# Weights are stored in (nr_out, nr_in) format, so we're basically
|
# Weights are stored in (nr_out, nr_in) format, so we're basically
|
||||||
# just adding rows here.
|
# just adding rows here.
|
||||||
if self.model[-1].is_noop:
|
smaller = self.model[-1]._layers[-1]
|
||||||
smaller = self.model[1]
|
larger = Affine(self.moves.n_moves, smaller.nI)
|
||||||
dims = dict(self.model[1]._dims)
|
copy_array(larger.W[:smaller.nO], smaller.W)
|
||||||
dims['nO'] = self.moves.n_moves
|
copy_array(larger.b[:smaller.nO], smaller.b)
|
||||||
larger = self.model[1].__class__(**dims)
|
self.model[-1]._layers[-1] = larger
|
||||||
copy_array(larger.W[:, :smaller.nO], smaller.W)
|
|
||||||
copy_array(larger.b[:smaller.nO], smaller.b)
|
|
||||||
self.model = (self.model[0], larger, self.model[2])
|
|
||||||
else:
|
|
||||||
smaller = self.model[-1]._layers[-1]
|
|
||||||
larger = Affine(self.moves.n_moves, smaller.nI)
|
|
||||||
copy_array(larger.W[:smaller.nO], smaller.W)
|
|
||||||
copy_array(larger.b[:smaller.nO], smaller.b)
|
|
||||||
self.model[-1]._layers[-1] = larger
|
|
||||||
|
|
||||||
def begin_training(self, gold_tuples, pipeline=None, **cfg):
|
def begin_training(self, gold_tuples, pipeline=None, **cfg):
|
||||||
if 'model' in cfg:
|
if 'model' in cfg:
|
||||||
|
@ -969,31 +928,6 @@ cdef int arg_max_if_valid(const weight_t* scores, const int* is_valid, int n) no
|
||||||
return best
|
return best
|
||||||
|
|
||||||
|
|
||||||
cdef int arg_maxout_if_valid(const weight_t* scores, const int* is_valid,
|
|
||||||
int n, int nP) nogil:
|
|
||||||
cdef int best = -1
|
|
||||||
cdef float best_score = 0
|
|
||||||
for i in range(n):
|
|
||||||
if is_valid[i] >= 1:
|
|
||||||
for j in range(nP):
|
|
||||||
if best == -1 or scores[i*nP+j] > best_score:
|
|
||||||
best = i
|
|
||||||
best_score = scores[i*nP+j]
|
|
||||||
return best
|
|
||||||
|
|
||||||
|
|
||||||
cdef int _arg_max_clas(const weight_t* scores, int move, const Transition* actions,
|
|
||||||
int nr_class) except -1:
|
|
||||||
cdef weight_t score = 0
|
|
||||||
cdef int mode = -1
|
|
||||||
cdef int i
|
|
||||||
for i in range(nr_class):
|
|
||||||
if actions[i].move == move and (mode == -1 or scores[i] >= score):
|
|
||||||
mode = i
|
|
||||||
score = scores[i]
|
|
||||||
return mode
|
|
||||||
|
|
||||||
|
|
||||||
# These are passed as callbacks to thinc.search.Beam
|
# These are passed as callbacks to thinc.search.Beam
|
||||||
cdef int _transition_state(void* _dest, void* _src, class_t clas, void* _moves) except -1:
|
cdef int _transition_state(void* _dest, void* _src, class_t clas, void* _moves) except -1:
|
||||||
dest = <StateClass>_dest
|
dest = <StateClass>_dest
|
||||||
|
|
|
@ -148,7 +148,8 @@ cdef class TransitionSystem:
|
||||||
|
|
||||||
def add_action(self, int action, label_name):
|
def add_action(self, int action, label_name):
|
||||||
cdef attr_t label_id
|
cdef attr_t label_id
|
||||||
if not isinstance(label_name, (int, long)):
|
if not isinstance(label_name, int) and \
|
||||||
|
not isinstance(label_name, long):
|
||||||
label_id = self.strings.add(label_name)
|
label_id = self.strings.add(label_name)
|
||||||
else:
|
else:
|
||||||
label_id = label_name
|
label_id = label_name
|
||||||
|
|
|
@ -315,30 +315,30 @@ p
|
||||||
+cell Number of rows in embedding tables.
|
+cell Number of rows in embedding tables.
|
||||||
+cell #[code 7500]
|
+cell #[code 7500]
|
||||||
|
|
||||||
+row
|
//- +row
|
||||||
+cell #[code parser_maxout_pieces]
|
//- +cell #[code parser_maxout_pieces]
|
||||||
+cell Number of pieces in the parser's and NER's first maxout layer.
|
//- +cell Number of pieces in the parser's and NER's first maxout layer.
|
||||||
+cell #[code 2]
|
//- +cell #[code 2]
|
||||||
|
|
||||||
+row
|
//- +row
|
||||||
+cell #[code parser_hidden_depth]
|
//- +cell #[code parser_hidden_depth]
|
||||||
+cell Number of hidden layers in the parser and NER.
|
//- +cell Number of hidden layers in the parser and NER.
|
||||||
+cell #[code 1]
|
//- +cell #[code 1]
|
||||||
|
|
||||||
+row
|
+row
|
||||||
+cell #[code hidden_width]
|
+cell #[code hidden_width]
|
||||||
+cell Size of the parser's and NER's hidden layers.
|
+cell Size of the parser's and NER's hidden layers.
|
||||||
+cell #[code 128]
|
+cell #[code 128]
|
||||||
|
|
||||||
+row
|
//- +row
|
||||||
+cell #[code history_feats]
|
//- +cell #[code history_feats]
|
||||||
+cell Number of previous action ID features for parser and NER.
|
//- +cell Number of previous action ID features for parser and NER.
|
||||||
+cell #[code 128]
|
//- +cell #[code 128]
|
||||||
|
|
||||||
+row
|
//- +row
|
||||||
+cell #[code history_width]
|
//- +cell #[code history_width]
|
||||||
+cell Number of embedding dimensions for each action ID.
|
//- +cell Number of embedding dimensions for each action ID.
|
||||||
+cell #[code 128]
|
//- +cell #[code 128]
|
||||||
|
|
||||||
+row
|
+row
|
||||||
+cell #[code learn_rate]
|
+cell #[code learn_rate]
|
||||||
|
|
Loading…
Reference in New Issue
Block a user