mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-15 06:09:01 +03:00
Support parser depth=0
This commit is contained in:
parent
eba89f08bd
commit
fef50277d7
|
@ -42,11 +42,17 @@ cdef WeightsC get_c_weights(model) except *:
|
||||||
cdef precompute_hiddens state2vec = model.state2vec
|
cdef precompute_hiddens state2vec = model.state2vec
|
||||||
output.feat_weights = state2vec.get_feat_weights()
|
output.feat_weights = state2vec.get_feat_weights()
|
||||||
output.feat_bias = <const float*>state2vec.bias.data
|
output.feat_bias = <const float*>state2vec.bias.data
|
||||||
cdef np.ndarray vec2scores_W = model.vec2scores.W
|
cdef np.ndarray vec2scores_W
|
||||||
cdef np.ndarray vec2scores_b = model.vec2scores.b
|
cdef np.ndarray vec2scores_b
|
||||||
|
if model.vec2scores is None:
|
||||||
|
output.hidden_weights = NULL
|
||||||
|
output.hidden_bias = NULL
|
||||||
|
else:
|
||||||
|
vec2scores_W = model.vec2scores.b
|
||||||
|
vec2scores_b = model.vec2scores.W
|
||||||
|
output.hidden_weights = <const float*>vec2scores_W.data
|
||||||
|
output.hidden_bias = <const float*>vec2scores_b.data
|
||||||
cdef np.ndarray class_mask = model._class_mask
|
cdef np.ndarray class_mask = model._class_mask
|
||||||
output.hidden_weights = <const float*>vec2scores_W.data
|
|
||||||
output.hidden_bias = <const float*>vec2scores_b.data
|
|
||||||
output.seen_classes = <const float*>class_mask.data
|
output.seen_classes = <const float*>class_mask.data
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
@ -54,7 +60,10 @@ cdef WeightsC get_c_weights(model) except *:
|
||||||
cdef SizesC get_c_sizes(model, int batch_size) except *:
|
cdef SizesC get_c_sizes(model, int batch_size) except *:
|
||||||
cdef SizesC output
|
cdef SizesC output
|
||||||
output.states = batch_size
|
output.states = batch_size
|
||||||
output.classes = model.vec2scores.nO
|
if model.vec2scores is None:
|
||||||
|
output.classes = model.state2vec.nO
|
||||||
|
else:
|
||||||
|
output.classes = model.vec2scores.nO
|
||||||
output.hiddens = model.state2vec.nO
|
output.hiddens = model.state2vec.nO
|
||||||
output.pieces = model.state2vec.nP
|
output.pieces = model.state2vec.nP
|
||||||
output.feats = model.state2vec.nF
|
output.feats = model.state2vec.nF
|
||||||
|
@ -90,11 +99,12 @@ cdef void resize_activations(ActivationsC* A, SizesC n) nogil:
|
||||||
|
|
||||||
cdef void predict_states(ActivationsC* A, StateC** states,
|
cdef void predict_states(ActivationsC* A, StateC** states,
|
||||||
const WeightsC* W, SizesC n) nogil:
|
const WeightsC* W, SizesC n) nogil:
|
||||||
|
cdef double one = 1.0
|
||||||
resize_activations(A, n)
|
resize_activations(A, n)
|
||||||
memset(A.unmaxed, 0, n.states * n.hiddens * n.pieces * sizeof(float))
|
|
||||||
memset(A.hiddens, 0, n.states * n.hiddens * sizeof(float))
|
|
||||||
for i in range(n.states):
|
for i in range(n.states):
|
||||||
states[i].set_context_tokens(&A.token_ids[i*n.feats], n.feats)
|
states[i].set_context_tokens(&A.token_ids[i*n.feats], n.feats)
|
||||||
|
memset(A.unmaxed, 0, n.states * n.hiddens * n.pieces * sizeof(float))
|
||||||
|
memset(A.hiddens, 0, n.states * n.hiddens * sizeof(float))
|
||||||
sum_state_features(A.unmaxed,
|
sum_state_features(A.unmaxed,
|
||||||
W.feat_weights, A.token_ids, n.states, n.feats, n.hiddens * n.pieces)
|
W.feat_weights, A.token_ids, n.states, n.feats, n.hiddens * n.pieces)
|
||||||
for i in range(n.states):
|
for i in range(n.states):
|
||||||
|
@ -105,18 +115,20 @@ cdef void predict_states(ActivationsC* A, StateC** states,
|
||||||
which = Vec.arg_max(&A.unmaxed[index], n.pieces)
|
which = Vec.arg_max(&A.unmaxed[index], n.pieces)
|
||||||
A.hiddens[i*n.hiddens + j] = A.unmaxed[index + which]
|
A.hiddens[i*n.hiddens + j] = A.unmaxed[index + which]
|
||||||
memset(A.scores, 0, n.states * n.classes * sizeof(float))
|
memset(A.scores, 0, n.states * n.classes * sizeof(float))
|
||||||
cdef double one = 1.0
|
if W.hidden_weights == NULL:
|
||||||
# Compute hidden-to-output
|
memcpy(A.scores, A.hiddens, n.states * n.classes * sizeof(float))
|
||||||
blis.cy.gemm(blis.cy.NO_TRANSPOSE, blis.cy.TRANSPOSE,
|
else:
|
||||||
n.states, n.classes, n.hiddens, one,
|
# Compute hidden-to-output
|
||||||
<float*>A.hiddens, n.hiddens, 1,
|
blis.cy.gemm(blis.cy.NO_TRANSPOSE, blis.cy.TRANSPOSE,
|
||||||
<float*>W.hidden_weights, n.hiddens, 1,
|
n.states, n.classes, n.hiddens, one,
|
||||||
one,
|
<float*>A.hiddens, n.hiddens, 1,
|
||||||
<float*>A.scores, n.classes, 1)
|
<float*>W.hidden_weights, n.hiddens, 1,
|
||||||
# Add bias
|
one,
|
||||||
for i in range(n.states):
|
<float*>A.scores, n.classes, 1)
|
||||||
VecVec.add_i(&A.scores[i*n.classes],
|
# Add bias
|
||||||
W.hidden_bias, 1., n.classes)
|
for i in range(n.states):
|
||||||
|
VecVec.add_i(&A.scores[i*n.classes],
|
||||||
|
W.hidden_bias, 1., n.classes)
|
||||||
# Set unseen classes to minimum value
|
# Set unseen classes to minimum value
|
||||||
i = 0
|
i = 0
|
||||||
min_ = A.scores[0]
|
min_ = A.scores[0]
|
||||||
|
@ -204,7 +216,9 @@ cdef int arg_max_if_valid(const weight_t* scores, const int* is_valid, int n) no
|
||||||
class ParserModel(Model):
|
class ParserModel(Model):
|
||||||
def __init__(self, tok2vec, lower_model, upper_model, unseen_classes=None):
|
def __init__(self, tok2vec, lower_model, upper_model, unseen_classes=None):
|
||||||
Model.__init__(self)
|
Model.__init__(self)
|
||||||
self._layers = [tok2vec, lower_model, upper_model]
|
self._layers = [tok2vec, lower_model]
|
||||||
|
if upper_model is not None:
|
||||||
|
self._layers.append(upper_model)
|
||||||
self.unseen_classes = set()
|
self.unseen_classes = set()
|
||||||
if unseen_classes:
|
if unseen_classes:
|
||||||
for class_ in unseen_classes:
|
for class_ in unseen_classes:
|
||||||
|
@ -219,6 +233,8 @@ class ParserModel(Model):
|
||||||
return step_model, finish_parser_update
|
return step_model, finish_parser_update
|
||||||
|
|
||||||
def resize_output(self, new_output):
|
def resize_output(self, new_output):
|
||||||
|
if len(self._layers) == 2:
|
||||||
|
return
|
||||||
if new_output == self.upper.nO:
|
if new_output == self.upper.nO:
|
||||||
return
|
return
|
||||||
smaller = self.upper
|
smaller = self.upper
|
||||||
|
@ -260,12 +276,24 @@ class ParserModel(Model):
|
||||||
class ParserStepModel(Model):
|
class ParserStepModel(Model):
|
||||||
def __init__(self, docs, layers, unseen_classes=None, drop=0.):
|
def __init__(self, docs, layers, unseen_classes=None, drop=0.):
|
||||||
self.tokvecs, self.bp_tokvecs = layers[0].begin_update(docs, drop=drop)
|
self.tokvecs, self.bp_tokvecs = layers[0].begin_update(docs, drop=drop)
|
||||||
|
if layers[1].nP >= 2:
|
||||||
|
activation = "maxout"
|
||||||
|
elif len(layers) == 2:
|
||||||
|
activation = None
|
||||||
|
else:
|
||||||
|
activation = "relu"
|
||||||
self.state2vec = precompute_hiddens(len(docs), self.tokvecs, layers[1],
|
self.state2vec = precompute_hiddens(len(docs), self.tokvecs, layers[1],
|
||||||
drop=drop)
|
activation=activation, drop=drop)
|
||||||
self.vec2scores = layers[-1]
|
if len(layers) == 3:
|
||||||
|
self.vec2scores = layers[-1]
|
||||||
|
else:
|
||||||
|
self.vec2scores = None
|
||||||
self.cuda_stream = util.get_cuda_stream(non_blocking=True)
|
self.cuda_stream = util.get_cuda_stream(non_blocking=True)
|
||||||
self.backprops = []
|
self.backprops = []
|
||||||
self._class_mask = numpy.zeros((self.vec2scores.nO,), dtype='f')
|
if self.vec2scores is None:
|
||||||
|
self._class_mask = numpy.zeros((self.state2vec.nO,), dtype='f')
|
||||||
|
else:
|
||||||
|
self._class_mask = numpy.zeros((self.vec2scores.nO,), dtype='f')
|
||||||
self._class_mask.fill(1)
|
self._class_mask.fill(1)
|
||||||
if unseen_classes is not None:
|
if unseen_classes is not None:
|
||||||
for class_ in unseen_classes:
|
for class_ in unseen_classes:
|
||||||
|
@ -287,10 +315,15 @@ class ParserStepModel(Model):
|
||||||
def begin_update(self, states, drop=0.):
|
def begin_update(self, states, drop=0.):
|
||||||
token_ids = self.get_token_ids(states)
|
token_ids = self.get_token_ids(states)
|
||||||
vector, get_d_tokvecs = self.state2vec.begin_update(token_ids, drop=0.0)
|
vector, get_d_tokvecs = self.state2vec.begin_update(token_ids, drop=0.0)
|
||||||
mask = self.vec2scores.ops.get_dropout_mask(vector.shape, drop)
|
if self.vec2scores is not None:
|
||||||
if mask is not None:
|
mask = self.vec2scores.ops.get_dropout_mask(vector.shape, drop)
|
||||||
vector *= mask
|
if mask is not None:
|
||||||
scores, get_d_vector = self.vec2scores.begin_update(vector, drop=drop)
|
vector *= mask
|
||||||
|
scores, get_d_vector = self.vec2scores.begin_update(vector, drop=drop)
|
||||||
|
else:
|
||||||
|
scores = NumpyOps().asarray(vector)
|
||||||
|
get_d_vector = lambda d_scores, sgd=None: d_scores
|
||||||
|
mask = None
|
||||||
# If the class is unseen, make sure its score is minimum
|
# If the class is unseen, make sure its score is minimum
|
||||||
scores[:, self._class_mask == 0] = numpy.nanmin(scores)
|
scores[:, self._class_mask == 0] = numpy.nanmin(scores)
|
||||||
|
|
||||||
|
@ -327,12 +360,12 @@ class ParserStepModel(Model):
|
||||||
return ids
|
return ids
|
||||||
|
|
||||||
def make_updates(self, sgd):
|
def make_updates(self, sgd):
|
||||||
# Tells CUDA to block, so our async copies complete.
|
|
||||||
if self.cuda_stream is not None:
|
|
||||||
self.cuda_stream.synchronize()
|
|
||||||
# Add a padding vector to the d_tokvecs gradient, so that missing
|
# Add a padding vector to the d_tokvecs gradient, so that missing
|
||||||
# values don't affect the real gradient.
|
# values don't affect the real gradient.
|
||||||
d_tokvecs = self.ops.allocate((self.tokvecs.shape[0]+1, self.tokvecs.shape[1]))
|
d_tokvecs = self.ops.allocate((self.tokvecs.shape[0]+1, self.tokvecs.shape[1]))
|
||||||
|
# Tells CUDA to block, so our async copies complete.
|
||||||
|
if self.cuda_stream is not None:
|
||||||
|
self.cuda_stream.synchronize()
|
||||||
for ids, d_vector, bp_vector in self.backprops:
|
for ids, d_vector, bp_vector in self.backprops:
|
||||||
d_state_features = bp_vector((d_vector, ids), sgd=sgd)
|
d_state_features = bp_vector((d_vector, ids), sgd=sgd)
|
||||||
ids = ids.flatten()
|
ids = ids.flatten()
|
||||||
|
@ -370,9 +403,10 @@ cdef class precompute_hiddens:
|
||||||
cdef np.ndarray bias
|
cdef np.ndarray bias
|
||||||
cdef object _cuda_stream
|
cdef object _cuda_stream
|
||||||
cdef object _bp_hiddens
|
cdef object _bp_hiddens
|
||||||
|
cdef object activation
|
||||||
|
|
||||||
def __init__(self, batch_size, tokvecs, lower_model, cuda_stream=None,
|
def __init__(self, batch_size, tokvecs, lower_model, cuda_stream=None,
|
||||||
drop=0.):
|
activation="maxout", drop=0.):
|
||||||
gpu_cached, bp_features = lower_model.begin_update(tokvecs, drop=drop)
|
gpu_cached, bp_features = lower_model.begin_update(tokvecs, drop=drop)
|
||||||
cdef np.ndarray cached
|
cdef np.ndarray cached
|
||||||
if not isinstance(gpu_cached, numpy.ndarray):
|
if not isinstance(gpu_cached, numpy.ndarray):
|
||||||
|
@ -390,6 +424,8 @@ cdef class precompute_hiddens:
|
||||||
self.nP = getattr(lower_model, 'nP', 1)
|
self.nP = getattr(lower_model, 'nP', 1)
|
||||||
self.nO = cached.shape[2]
|
self.nO = cached.shape[2]
|
||||||
self.ops = lower_model.ops
|
self.ops = lower_model.ops
|
||||||
|
assert activation in (None, "relu", "maxout")
|
||||||
|
self.activation = activation
|
||||||
self._is_synchronized = False
|
self._is_synchronized = False
|
||||||
self._cuda_stream = cuda_stream
|
self._cuda_stream = cuda_stream
|
||||||
self._cached = cached
|
self._cached = cached
|
||||||
|
@ -402,7 +438,7 @@ cdef class precompute_hiddens:
|
||||||
return <float*>self._cached.data
|
return <float*>self._cached.data
|
||||||
|
|
||||||
def __call__(self, X):
|
def __call__(self, X):
|
||||||
return self.begin_update(X)[0]
|
return self.begin_update(X, drop=None)[0]
|
||||||
|
|
||||||
def begin_update(self, token_ids, drop=0.):
|
def begin_update(self, token_ids, drop=0.):
|
||||||
cdef np.ndarray state_vector = numpy.zeros(
|
cdef np.ndarray state_vector = numpy.zeros(
|
||||||
|
@ -435,28 +471,31 @@ cdef class precompute_hiddens:
|
||||||
else:
|
else:
|
||||||
ops = CupyOps()
|
ops = CupyOps()
|
||||||
|
|
||||||
if self.nP == 1:
|
if self.activation == "maxout":
|
||||||
state_vector = state_vector.reshape(state_vector.shape[:-1])
|
|
||||||
mask = state_vector >= 0.
|
|
||||||
state_vector *= mask
|
|
||||||
else:
|
|
||||||
state_vector, mask = ops.maxout(state_vector)
|
state_vector, mask = ops.maxout(state_vector)
|
||||||
|
else:
|
||||||
|
state_vector = state_vector.reshape(state_vector.shape[:-1])
|
||||||
|
if self.activation == "relu":
|
||||||
|
mask = state_vector >= 0.
|
||||||
|
state_vector *= mask
|
||||||
|
|
||||||
def backprop_nonlinearity(d_best, sgd=None):
|
def backprop_nonlinearity(d_best, sgd=None):
|
||||||
if isinstance(d_best, numpy.ndarray):
|
if isinstance(d_best, numpy.ndarray):
|
||||||
ops = NumpyOps()
|
ops = NumpyOps()
|
||||||
else:
|
else:
|
||||||
ops = CupyOps()
|
ops = CupyOps()
|
||||||
mask_ = ops.asarray(mask)
|
|
||||||
|
|
||||||
# This will usually be on GPU
|
# This will usually be on GPU
|
||||||
d_best = ops.asarray(d_best)
|
d_best = ops.asarray(d_best)
|
||||||
# Fix nans (which can occur from unseen classes.)
|
# Fix nans (which can occur from unseen classes.)
|
||||||
d_best[ops.xp.isnan(d_best)] = 0.
|
d_best[ops.xp.isnan(d_best)] = 0.
|
||||||
if self.nP == 1:
|
if self.activation == "maxout":
|
||||||
|
mask_ = ops.asarray(mask)
|
||||||
|
return ops.backprop_maxout(d_best, mask_, self.nP)
|
||||||
|
elif self.activation == "relu":
|
||||||
|
mask_ = ops.asarray(mask)
|
||||||
d_best *= mask_
|
d_best *= mask_
|
||||||
d_best = d_best.reshape((d_best.shape + (1,)))
|
d_best = d_best.reshape((d_best.shape + (1,)))
|
||||||
return d_best
|
return d_best
|
||||||
else:
|
else:
|
||||||
return ops.backprop_maxout(d_best, mask_, self.nP)
|
return d_best.reshape((d_best.shape + (1,)))
|
||||||
return state_vector, backprop_nonlinearity
|
return state_vector, backprop_nonlinearity
|
||||||
|
|
|
@ -22,7 +22,7 @@ from thinc.extra.search cimport Beam
|
||||||
from thinc.api import chain, clone
|
from thinc.api import chain, clone
|
||||||
from thinc.v2v import Model, Maxout, Affine
|
from thinc.v2v import Model, Maxout, Affine
|
||||||
from thinc.misc import LayerNorm
|
from thinc.misc import LayerNorm
|
||||||
from thinc.neural.ops import CupyOps
|
from thinc.neural.ops import NumpyOps, CupyOps
|
||||||
from thinc.neural.util import get_array_module
|
from thinc.neural.util import get_array_module
|
||||||
from thinc.linalg cimport Vec, VecVec
|
from thinc.linalg cimport Vec, VecVec
|
||||||
import srsly
|
import srsly
|
||||||
|
@ -56,19 +56,24 @@ cdef class Parser:
|
||||||
subword_features = util.env_opt('subword_features',
|
subword_features = util.env_opt('subword_features',
|
||||||
cfg.get('subword_features', True))
|
cfg.get('subword_features', True))
|
||||||
conv_depth = util.env_opt('conv_depth', cfg.get('conv_depth', 4))
|
conv_depth = util.env_opt('conv_depth', cfg.get('conv_depth', 4))
|
||||||
|
conv_window = util.env_opt('conv_window', cfg.get('conv_depth', 1))
|
||||||
bilstm_depth = util.env_opt('bilstm_depth', cfg.get('bilstm_depth', 0))
|
bilstm_depth = util.env_opt('bilstm_depth', cfg.get('bilstm_depth', 0))
|
||||||
self_attn_depth = util.env_opt('self_attn_depth', cfg.get('self_attn_depth', 0))
|
self_attn_depth = util.env_opt('self_attn_depth', cfg.get('self_attn_depth', 0))
|
||||||
if depth != 1:
|
if depth not in (0, 1):
|
||||||
raise ValueError(TempErrors.T004.format(value=depth))
|
raise ValueError(TempErrors.T004.format(value=depth))
|
||||||
parser_maxout_pieces = util.env_opt('parser_maxout_pieces',
|
parser_maxout_pieces = util.env_opt('parser_maxout_pieces',
|
||||||
cfg.get('maxout_pieces', 2))
|
cfg.get('maxout_pieces', 2))
|
||||||
token_vector_width = util.env_opt('token_vector_width',
|
token_vector_width = util.env_opt('token_vector_width',
|
||||||
cfg.get('token_vector_width', 96))
|
cfg.get('token_vector_width', 96))
|
||||||
hidden_width = util.env_opt('hidden_width', cfg.get('hidden_width', 64))
|
hidden_width = util.env_opt('hidden_width', cfg.get('hidden_width', 64))
|
||||||
|
if depth == 0:
|
||||||
|
hidden_width = nr_class
|
||||||
|
parser_maxout_pieces = 1
|
||||||
embed_size = util.env_opt('embed_size', cfg.get('embed_size', 2000))
|
embed_size = util.env_opt('embed_size', cfg.get('embed_size', 2000))
|
||||||
pretrained_vectors = cfg.get('pretrained_vectors', None)
|
pretrained_vectors = cfg.get('pretrained_vectors', None)
|
||||||
tok2vec = Tok2Vec(token_vector_width, embed_size,
|
tok2vec = Tok2Vec(token_vector_width, embed_size,
|
||||||
conv_depth=conv_depth,
|
conv_depth=conv_depth,
|
||||||
|
conv_window=conv_window,
|
||||||
subword_features=subword_features,
|
subword_features=subword_features,
|
||||||
pretrained_vectors=pretrained_vectors,
|
pretrained_vectors=pretrained_vectors,
|
||||||
bilstm_depth=bilstm_depth,
|
bilstm_depth=bilstm_depth,
|
||||||
|
@ -79,10 +84,12 @@ cdef class Parser:
|
||||||
nF=cls.nr_feature, nI=token_vector_width,
|
nF=cls.nr_feature, nI=token_vector_width,
|
||||||
nP=parser_maxout_pieces)
|
nP=parser_maxout_pieces)
|
||||||
lower.nP = parser_maxout_pieces
|
lower.nP = parser_maxout_pieces
|
||||||
|
if depth == 1:
|
||||||
with Model.use_device('cpu'):
|
with Model.use_device('cpu'):
|
||||||
upper = Affine(nr_class, hidden_width, drop_factor=0.0)
|
upper = Affine(nr_class, hidden_width, drop_factor=0.0)
|
||||||
upper.W *= 0
|
upper.W *= 0
|
||||||
|
else:
|
||||||
|
upper = None
|
||||||
|
|
||||||
cfg = {
|
cfg = {
|
||||||
'nr_class': nr_class,
|
'nr_class': nr_class,
|
||||||
|
@ -94,6 +101,7 @@ cdef class Parser:
|
||||||
'bilstm_depth': bilstm_depth,
|
'bilstm_depth': bilstm_depth,
|
||||||
'self_attn_depth': self_attn_depth,
|
'self_attn_depth': self_attn_depth,
|
||||||
'conv_depth': conv_depth,
|
'conv_depth': conv_depth,
|
||||||
|
'conv_window': conv_window,
|
||||||
'embed_size': embed_size
|
'embed_size': embed_size
|
||||||
}
|
}
|
||||||
return ParserModel(tok2vec, lower, upper), cfg
|
return ParserModel(tok2vec, lower, upper), cfg
|
||||||
|
|
Loading…
Reference in New Issue
Block a user