Mirror of https://github.com/explosion/spaCy.git (synced 2025-08-04 12:20:20 +03:00)
Rename some identifiers in the parser refactor (#10935)
* Rename _parseC to _parse_batch
* tb_framework: prefix many auxiliary functions with underscore, to clearly state the intent that they are private
* Rename `lower` to `hidden`, `upper` to `output`
Parent: 63e90dd6a1
Commit: bc36c71982
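For code that introspects the parser model, the change is purely a renaming of refs and parameters: the `upper` ref becomes `output`, and the `lower_*` params become `hidden_*`. Below is a minimal sketch of the before/after access pattern, assuming Thinc is installed; the stub model is hypothetical and only exists to make the snippet self-contained (the real model is built by `TransitionModel` in the diff below).

# Hypothetical stand-in for the model built by TransitionModel(); only the
# renamed pieces are registered so that the snippet runs on its own.
from thinc.api import Linear, Model, zero_init

output = Linear(nO=None, nI=64, init_W=zero_init)  # formerly the `upper` layer

model = Model(
    name="parser_model",
    forward=lambda model, X, is_train: (X, lambda dY: dY),  # stub forward pass
    layers=[output],
    refs={"output": output},                      # was: refs={"upper": upper}
    params={"hidden_W": None, "hidden_b": None},  # was: "lower_W", "lower_b"
)

# Old name                    ->  new name
# model.get_ref("upper")      ->  model.get_ref("output")
# model.get_param("lower_W")  ->  model.get_param("hidden_W")
assert model.get_ref("output") is output
assert model.maybe_get_param("hidden_W") is None  # registered but not yet set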
@@ -45,28 +45,28 @@ def TransitionModel(
     tok2vec_projected = chain(tok2vec, list2array(), Linear(hidden_width, t2v_width))  # type: ignore
     tok2vec_projected.set_dim("nO", hidden_width)

-    # FIXME: we use `upper` as a container for the upper layer's
+    # FIXME: we use `output` as a container for the output layer's
     # weights and biases. Thinc optimizers cannot handle resizing
     # of parameters. So, when the parser model is resized, we
-    # construct a new `upper` layer, which has a different key in
+    # construct a new `output` layer, which has a different key in
     # the optimizer. Once the optimizer supports parameter resizing,
-    # we can replace the `upper` layer by `upper_W` and `upper_b`
+    # we can replace the `output` layer by `output_W` and `output_b`
     # parameters in this model.
-    upper = Linear(nO=None, nI=hidden_width, init_W=zero_init)
+    output = Linear(nO=None, nI=hidden_width, init_W=zero_init)

     return Model(
         name="parser_model",
         forward=forward,
         init=init,
-        layers=[tok2vec_projected, upper],
+        layers=[tok2vec_projected, output],
         refs={
             "tok2vec": tok2vec_projected,
-            "upper": upper,
+            "output": output,
         },
         params={
-            "lower_W": None,  # Floats2d W for the hidden layer
-            "lower_b": None,  # Floats1d bias for the hidden layer
-            "lower_pad": None,  # Floats1d padding for the hidden layer
+            "hidden_W": None,  # Floats2d W for the hidden layer
+            "hidden_b": None,  # Floats1d bias for the hidden layer
+            "hidden_pad": None,  # Floats1d padding for the hidden layer
         },
         dims={
             "nO": None,  # Output size
@@ -86,28 +86,28 @@ def TransitionModel(

 def resize_output(model: Model, new_nO: int) -> Model:
     old_nO = model.maybe_get_dim("nO")
-    upper = model.get_ref("upper")
+    output = model.get_ref("output")
     if old_nO is None:
         model.set_dim("nO", new_nO)
-        upper.set_dim("nO", new_nO)
-        upper.initialize()
+        output.set_dim("nO", new_nO)
+        output.initialize()
         return model
     elif new_nO <= old_nO:
         return model
-    elif upper.has_param("W"):
+    elif output.has_param("W"):
         nH = model.get_dim("nH")
-        new_upper = Linear(nO=new_nO, nI=nH, init_W=zero_init)
-        new_upper.initialize()
-        new_W = new_upper.get_param("W")
-        new_b = new_upper.get_param("b")
-        old_W = upper.get_param("W")
-        old_b = upper.get_param("b")
+        new_output = Linear(nO=new_nO, nI=nH, init_W=zero_init)
+        new_output.initialize()
+        new_W = new_output.get_param("W")
+        new_b = new_output.get_param("b")
+        old_W = output.get_param("W")
+        old_b = output.get_param("b")
         new_W[:old_nO] = old_W  # type: ignore
         new_b[:old_nO] = old_b  # type: ignore
         for i in range(old_nO, new_nO):
             model.attrs["unseen_classes"].add(i)
-        model.layers[-1] = new_upper
-        model.set_ref("upper", new_upper)
+        model.layers[-1] = new_output
+        model.set_ref("output", new_output)
     # TODO: Avoid this private intrusion
     model._dims["nO"] = new_nO
     return model
@@ -141,10 +141,10 @@ def init(
     # Wl = zero_init(ops, Wl.shape)
     Wl = glorot_uniform_init(ops, Wl.shape)
     padl = uniform_init(ops, padl.shape)  # type: ignore
-    # TODO: Experiment with whether better to initialize upper_W
-    model.set_param("lower_W", Wl)
-    model.set_param("lower_b", bl)
-    model.set_param("lower_pad", padl)
+    # TODO: Experiment with whether better to initialize output_W
+    model.set_param("hidden_W", Wl)
+    model.set_param("hidden_b", bl)
+    model.set_param("hidden_pad", padl)
     # model = _lsuv_init(model)
     return model

@@ -160,12 +160,12 @@ def forward(model, docs_moves: InT, is_train: bool):
     docs, moves, actions = docs_moves

     beam_width = model.attrs["beam_width"]
-    lower_pad = model.get_param("lower_pad")
+    hidden_pad = model.get_param("hidden_pad")
     tok2vec = model.get_ref("tok2vec")

     states = moves.init_batch(docs)
     tokvecs, backprop_tok2vec = tok2vec(docs, is_train)
-    tokvecs = model.ops.xp.vstack((tokvecs, lower_pad))
+    tokvecs = model.ops.xp.vstack((tokvecs, hidden_pad))
     feats, backprop_feats = _forward_precomputable_affine(model, tokvecs, is_train)
     seen_mask = _get_seen_mask(model)

@@ -183,23 +183,23 @@ def _forward_greedy_cpu(model: Model, TransitionSystem moves, states: List[State
     for state in states:
         if not state.is_final():
             c_states.push_back(state.c)
-    weights = get_c_weights(model, <float*>feats.data, seen_mask)
+    weights = _get_c_weights(model, <float*>feats.data, seen_mask)
     # Precomputed features have rows for each token, plus one for padding.
     cdef int n_tokens = feats.shape[0] - 1
-    sizes = get_c_sizes(model, c_states.size(), n_tokens)
+    sizes = _get_c_sizes(model, c_states.size(), n_tokens)
     cdef CBlas cblas = model.ops.cblas()
-    scores = _parseC(cblas, moves, &c_states[0], weights, sizes, actions=actions)
+    scores = _parse_batch(cblas, moves, &c_states[0], weights, sizes, actions=actions)

     def backprop(dY):
         raise ValueError(Errors.E1038)

     return (states, scores), backprop

-cdef list _parseC(CBlas cblas, TransitionSystem moves, StateC** states,
-                  WeightsC weights, SizesC sizes, actions: Optional[List[Ints1d]]=None):
+cdef list _parse_batch(CBlas cblas, TransitionSystem moves, StateC** states,
+                       WeightsC weights, SizesC sizes, actions: Optional[List[Ints1d]]=None):
     cdef int i, j
     cdef vector[StateC *] unfinished
-    cdef ActivationsC activations = alloc_activations(sizes)
+    cdef ActivationsC activations = _alloc_activations(sizes)
     cdef np.ndarray step_scores
     cdef np.ndarray step_actions

@@ -208,7 +208,7 @@ cdef list _parseC(CBlas cblas, TransitionSystem moves, StateC** states,
         step_scores = numpy.empty((sizes.states, sizes.classes), dtype="f")
         step_actions = actions[0] if actions is not None else None
         with nogil:
-            predict_states(cblas, &activations, <float*>step_scores.data, states, &weights, sizes)
+            _predict_states(cblas, &activations, <float*>step_scores.data, states, &weights, sizes)
             if actions is None:
                 # Validate actions, argmax, take action.
                 c_transition_batch(moves, states, <const float*>step_scores.data, sizes.classes,
@@ -224,7 +224,7 @@ cdef list _parseC(CBlas cblas, TransitionSystem moves, StateC** states,
         scores.append(step_scores)
         unfinished.clear()
         actions = actions[1:] if actions is not None else None
-    free_activations(&activations)
+    _free_activations(&activations)

     return scores

@@ -232,8 +232,8 @@ cdef list _parseC(CBlas cblas, TransitionSystem moves, StateC** states,
 def _forward_fallback(model: Model, moves: TransitionSystem, states: List[StateClass], tokvecs, backprop_tok2vec, feats, backprop_feats, seen_mask, is_train: bool,
                       actions: Optional[List[Ints1d]]=None):
     nF = model.get_dim("nF")
-    upper = model.get_ref("upper")
-    lower_b = model.get_param("lower_b")
+    output = model.get_ref("output")
+    hidden_b = model.get_param("hidden_b")
     nH = model.get_dim("nH")
     nP = model.get_dim("nP")

@@ -260,13 +260,13 @@ def _forward_fallback(model: Model, moves: TransitionSystem, states: List[StateC
         # Sum the state features, add the bias and apply the activation (maxout)
         # to create the state vectors.
         preacts2f = feats[ids, arange].sum(axis=1)  # type: ignore
-        preacts2f += lower_b
+        preacts2f += hidden_b
         preacts = ops.reshape3f(preacts2f, preacts2f.shape[0], nH, nP)
         assert preacts.shape[0] == len(batch.get_unfinished_states()), preacts.shape
         statevecs, which = ops.maxout(preacts)
-        # We don't use upper's backprop, since we want to backprop for
+        # We don't use output's backprop, since we want to backprop for
         # all states at once, rather than a single state.
-        scores = upper.predict(statevecs)
+        scores = output.predict(statevecs)
         scores[:, seen_mask] = ops.xp.nanmin(scores)
         # Transition the states, filtering out any that are finished.
         cpu_scores = ops.to_numpy(scores)
@@ -295,23 +295,23 @@ def _forward_fallback(model: Model, moves: TransitionSystem, states: List[StateC
             if (d_scores[:, clas] < 0).any():
                 model.attrs["unseen_classes"].remove(clas)
         d_scores *= seen_mask == False
-        # Calculate the gradients for the parameters of the upper layer.
+        # Calculate the gradients for the parameters of the output layer.
         # The weight gemm is (nS, nO) @ (nS, nH).T
-        upper.inc_grad("b", d_scores.sum(axis=0))
-        upper.inc_grad("W", ops.gemm(d_scores, statevecs, trans1=True))
-        # Now calculate d_statevecs, by backproping through the upper linear layer.
+        output.inc_grad("b", d_scores.sum(axis=0))
+        output.inc_grad("W", ops.gemm(d_scores, statevecs, trans1=True))
+        # Now calculate d_statevecs, by backproping through the output linear layer.
         # This gemm is (nS, nO) @ (nO, nH)
-        upper_W = upper.get_param("W")
-        d_statevecs = ops.gemm(d_scores, upper_W)
+        output_W = output.get_param("W")
+        d_statevecs = ops.gemm(d_scores, output_W)
         # Backprop through the maxout activation
         d_preacts = ops.backprop_maxout(d_statevecs, which, nP)
         d_preacts2f = ops.reshape2f(d_preacts, d_preacts.shape[0], nH * nP)
-        model.inc_grad("lower_b", d_preacts2f.sum(axis=0))
+        model.inc_grad("hidden_b", d_preacts2f.sum(axis=0))
         # We don't need to backprop the summation, because we pass back the IDs instead
         d_state_features = backprop_feats((d_preacts2f, ids))
         d_tokvecs = ops.alloc2f(tokvecs.shape[0], tokvecs.shape[1])
         ops.scatter_add(d_tokvecs, ids, d_state_features)
-        model.inc_grad("lower_pad", d_tokvecs[-1])
+        model.inc_grad("hidden_pad", d_tokvecs[-1])
         return (backprop_tok2vec(d_tokvecs[:-1]), None)

     return (list(batch), all_scores), backprop_parser
@@ -323,10 +323,10 @@ def _forward_reference(
     """Slow reference implementation, without the precomputation"""
     nF = model.get_dim("nF")
     tok2vec = model.get_ref("tok2vec")
-    upper = model.get_ref("upper")
-    lower_pad = model.get_param("lower_pad")
-    lower_W = model.get_param("lower_W")
-    lower_b = model.get_param("lower_b")
+    output = model.get_ref("output")
+    hidden_pad = model.get_param("hidden_pad")
+    hidden_W = model.get_param("hidden_W")
+    hidden_b = model.get_param("hidden_b")
     nH = model.get_dim("nH")
     nP = model.get_dim("nP")
     nO = model.get_dim("nO")
@@ -336,7 +336,7 @@ def _forward_reference(
     docs, moves = docs_moves
     states = moves.init_batch(docs)
     tokvecs, backprop_tok2vec = tok2vec(docs, is_train)
-    tokvecs = model.ops.xp.vstack((tokvecs, lower_pad))
+    tokvecs = model.ops.xp.vstack((tokvecs, hidden_pad))
     all_ids = []
     all_which = []
     all_statevecs = []
@@ -353,13 +353,13 @@ def _forward_reference(
         # to create the state vectors.
         tokfeats3f = tokvecs[ids]
         tokfeats = model.ops.reshape2f(tokfeats3f, tokfeats3f.shape[0], -1)
-        preacts2f = model.ops.gemm(tokfeats, lower_W, trans2=True)
-        preacts2f += lower_b
+        preacts2f = model.ops.gemm(tokfeats, hidden_W, trans2=True)
+        preacts2f += hidden_b
         preacts = model.ops.reshape3f(preacts2f, preacts2f.shape[0], nH, nP)
         statevecs, which = ops.maxout(preacts)
-        # We don't use upper's backprop, since we want to backprop for
+        # We don't use output's backprop, since we want to backprop for
         # all states at once, rather than a single state.
-        scores = upper.predict(statevecs)
+        scores = output.predict(statevecs)
         scores[:, seen_mask] = model.ops.xp.nanmin(scores)
         # Transition the states, filtering out any that are finished.
         next_states = moves.transition_states(next_states, scores)
@@ -390,28 +390,28 @@ def _forward_reference(
         d_scores *= seen_mask == False
         assert statevecs.shape == (nS, nH), statevecs.shape
         assert d_scores.shape == (nS, nO), d_scores.shape
-        # Calculate the gradients for the parameters of the upper layer.
+        # Calculate the gradients for the parameters of the output layer.
         # The weight gemm is (nS, nO) @ (nS, nH).T
-        upper.inc_grad("b", d_scores.sum(axis=0))
-        upper.inc_grad("W", model.ops.gemm(d_scores, statevecs, trans1=True))
-        # Now calculate d_statevecs, by backproping through the upper linear layer.
+        output.inc_grad("b", d_scores.sum(axis=0))
+        output.inc_grad("W", model.ops.gemm(d_scores, statevecs, trans1=True))
+        # Now calculate d_statevecs, by backproping through the output linear layer.
         # This gemm is (nS, nO) @ (nO, nH)
-        upper_W = upper.get_param("W")
-        d_statevecs = model.ops.gemm(d_scores, upper_W)
+        output_W = output.get_param("W")
+        d_statevecs = model.ops.gemm(d_scores, output_W)
         # Backprop through the maxout activation
         d_preacts = model.ops.backprop_maxout(d_statevecs, which, nP)
         d_preacts2f = model.ops.reshape2f(d_preacts, d_preacts.shape[0], nH * nP)
-        # Now increment the gradients for the lower layer.
+        # Now increment the gradients for the hidden layer.
         # The gemm here is (nS, nH*nP) @ (nS, nF*nI)
-        model.inc_grad("lower_b", d_preacts2f.sum(axis=0))
-        model.inc_grad("lower_W", model.ops.gemm(d_preacts2f, tokfeats, trans1=True))
+        model.inc_grad("hidden_b", d_preacts2f.sum(axis=0))
+        model.inc_grad("hidden_W", model.ops.gemm(d_preacts2f, tokfeats, trans1=True))
         # Caclulate d_tokfeats
         # The gemm here is (nS, nH*nP) @ (nH*nP, nF*nI)
-        d_tokfeats = model.ops.gemm(d_preacts2f, lower_W)
+        d_tokfeats = model.ops.gemm(d_preacts2f, hidden_W)
         # Get the gradients of the tokvecs and the padding
         d_tokfeats3f = model.ops.reshape3f(d_tokfeats, nS, nF, nI)
         model.ops.scatter_add(d_tokvecs, ids, d_tokfeats3f)
-        model.inc_grad("lower_pad", d_tokvecs[-1])
+        model.inc_grad("hidden_pad", d_tokvecs[-1])
         return (backprop_tok2vec(d_tokvecs[:-1]), None)

     return (states, all_scores), backprop_parser
@@ -425,7 +425,7 @@ def _get_seen_mask(model: Model) -> numpy.array[bool, 1]:


 def _forward_precomputable_affine(model, X: Floats2d, is_train: bool):
-    W: Floats2d = model.get_param("lower_W")
+    W: Floats2d = model.get_param("hidden_W")
     nF = model.get_dim("nF")
     nH = model.get_dim("nH")
     nP = model.get_dim("nP")
@@ -456,7 +456,7 @@ def _forward_precomputable_affine(model, X: Floats2d, is_train: bool):
         dXf = model.ops.gemm(dY, W)
         Xf = X[ids].reshape((ids.shape[0], -1))
         dW = model.ops.gemm(dY, Xf, trans1=True)
-        model.inc_grad("lower_W", dW)
+        model.inc_grad("hidden_W", dW)
         return model.ops.reshape3f(dXf, dXf.shape[0], nF, nI)

     return Yf, backward
@@ -482,7 +482,7 @@ def _lsuv_init(model: Model):
     we set the maxout weights to values that empirically result in
     whitened outputs given whitened inputs.
     """
-    W = model.maybe_get_param("lower_W")
+    W = model.maybe_get_param("hidden_W")
     if W is not None and W.any():
         return

@@ -523,66 +523,66 @@ def _lsuv_init(model: Model):
     tol_var = 0.01
     tol_mean = 0.01
     t_max = 10
-    W = cast(Floats4d, model.get_param("lower_W").copy())
-    b = cast(Floats2d, model.get_param("lower_b").copy())
+    W = cast(Floats4d, model.get_param("hidden_W").copy())
+    b = cast(Floats2d, model.get_param("hidden_b").copy())
     for t_i in range(t_max):
         acts1 = predict(ids, tokvecs)
         var = model.ops.xp.var(acts1)
         mean = model.ops.xp.mean(acts1)
         if abs(var - 1.0) >= tol_var:
             W /= model.ops.xp.sqrt(var)
-            model.set_param("lower_W", W)
+            model.set_param("hidden_W", W)
         elif abs(mean) >= tol_mean:
             b -= mean
-            model.set_param("lower_b", b)
+            model.set_param("hidden_b", b)
         else:
             break
     return model


-cdef WeightsC get_c_weights(model, const float* feats, np.ndarray[np.npy_bool, ndim=1] seen_mask) except *:
-    upper = model.get_ref("upper")
-    cdef np.ndarray lower_b = model.get_param("lower_b")
-    cdef np.ndarray upper_W = upper.get_param("W")
-    cdef np.ndarray upper_b = upper.get_param("b")
+cdef WeightsC _get_c_weights(model, const float* feats, np.ndarray[np.npy_bool, ndim=1] seen_mask) except *:
+    output = model.get_ref("output")
+    cdef np.ndarray hidden_b = model.get_param("hidden_b")
+    cdef np.ndarray output_W = output.get_param("W")
+    cdef np.ndarray output_b = output.get_param("b")

-    cdef WeightsC output
-    output.feat_weights = feats
-    output.feat_bias = <const float*>lower_b.data
-    output.hidden_weights = <const float *> upper_W.data
-    output.hidden_bias = <const float *> upper_b.data
-    output.seen_mask = <const int8_t*> seen_mask.data
+    cdef WeightsC weights
+    weights.feat_weights = feats
+    weights.feat_bias = <const float*>hidden_b.data
+    weights.hidden_weights = <const float *> output_W.data
+    weights.hidden_bias = <const float *> output_b.data
+    weights.seen_mask = <const int8_t*> seen_mask.data

-    return output
+    return weights


-cdef SizesC get_c_sizes(model, int batch_size, int tokens) except *:
-    cdef SizesC output
-    output.states = batch_size
-    output.classes = model.get_dim("nO")
-    output.hiddens = model.get_dim("nH")
-    output.pieces = model.get_dim("nP")
-    output.feats = model.get_dim("nF")
-    output.embed_width = model.get_dim("nI")
-    output.tokens = tokens
-    return output
+cdef SizesC _get_c_sizes(model, int batch_size, int tokens) except *:
+    cdef SizesC sizes
+    sizes.states = batch_size
+    sizes.classes = model.get_dim("nO")
+    sizes.hiddens = model.get_dim("nH")
+    sizes.pieces = model.get_dim("nP")
+    sizes.feats = model.get_dim("nF")
+    sizes.embed_width = model.get_dim("nI")
+    sizes.tokens = tokens
+    return sizes


-cdef ActivationsC alloc_activations(SizesC n) nogil:
+cdef ActivationsC _alloc_activations(SizesC n) nogil:
     cdef ActivationsC A
     memset(&A, 0, sizeof(A))
-    resize_activations(&A, n)
+    _resize_activations(&A, n)
     return A


-cdef void free_activations(const ActivationsC* A) nogil:
+cdef void _free_activations(const ActivationsC* A) nogil:
     free(A.token_ids)
     free(A.unmaxed)
     free(A.hiddens)
     free(A.is_valid)


-cdef void resize_activations(ActivationsC* A, SizesC n) nogil:
+cdef void _resize_activations(ActivationsC* A, SizesC n) nogil:
     if n.states <= A._max_size:
         A._curr_size = n.states
         return
@@ -605,12 +605,12 @@ cdef void resize_activations(ActivationsC* A, SizesC n) nogil:
     A._curr_size = n.states


-cdef void predict_states(CBlas cblas, ActivationsC* A, float* scores, StateC** states, const WeightsC* W, SizesC n) nogil:
-    resize_activations(A, n)
+cdef void _predict_states(CBlas cblas, ActivationsC* A, float* scores, StateC** states, const WeightsC* W, SizesC n) nogil:
+    _resize_activations(A, n)
     for i in range(n.states):
         states[i].set_context_tokens(&A.token_ids[i*n.feats], n.feats)
     memset(A.unmaxed, 0, n.states * n.hiddens * n.pieces * sizeof(float))
-    sum_state_features(cblas, A.unmaxed, W.feat_weights, A.token_ids, n)
+    _sum_state_features(cblas, A.unmaxed, W.feat_weights, A.token_ids, n)
     for i in range(n.states):
         VecVec.add_i(&A.unmaxed[i*n.hiddens*n.pieces],
             W.feat_bias, 1., n.hiddens * n.pieces)
@@ -641,7 +641,7 @@ cdef void predict_states(CBlas cblas, ActivationsC* A, float* scores, StateC** s
             scores[i*n.classes+j] = min_


-cdef void sum_state_features(CBlas cblas, float* output,
+cdef void _sum_state_features(CBlas cblas, float* output,
         const float* cached, const int* token_ids, SizesC n) nogil:
     cdef int idx, b, f, i
     cdef const float* feature

@@ -263,12 +263,12 @@ def test_serialize_custom_nlp():
         nlp2 = spacy.load(d)
         model = nlp2.get_pipe("parser").model
         assert model.get_ref("tok2vec") is not None
-        assert model.has_param("lower_W")
-        assert model.has_param("lower_b")
-        upper = model.get_ref("upper")
-        assert upper is not None
-        assert upper.has_param("W")
-        assert upper.has_param("b")
+        assert model.has_param("hidden_W")
+        assert model.has_param("hidden_b")
+        output = model.get_ref("output")
+        assert output is not None
+        assert output.has_param("W")
+        assert output.has_param("b")


 @pytest.mark.parametrize("parser_config_string", [parser_config_string_upper])
@@ -285,12 +285,12 @@ def test_serialize_parser(parser_config_string):
         nlp2 = spacy.load(d)
         model = nlp2.get_pipe("parser").model
         assert model.get_ref("tok2vec") is not None
-        assert model.has_param("lower_W")
-        assert model.has_param("lower_b")
-        upper = model.get_ref("upper")
-        assert upper is not None
-        assert upper.has_param("b")
-        assert upper.has_param("W")
+        assert model.has_param("hidden_W")
+        assert model.has_param("hidden_b")
+        output = model.get_ref("output")
+        assert output is not None
+        assert output.has_param("b")
+        assert output.has_param("W")


 def test_config_nlp_roundtrip():