Rename some identifiers in the parser refactor (#10935)

* Rename _parseC to _parse_batch

* tb_framework: prefix many auxiliary functions with an underscore

To make clear that they are private.

* Rename `lower` to `hidden` and `upper` to `output`
Daniël de Kok, 2022-06-09 13:45:30 +02:00 (committed by GitHub)
parent 63e90dd6a1
commit bc36c71982
2 changed files with 117 additions and 117 deletions
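Nothing in the diff changes the model's structure; only string keys and identifiers are renamed. A rough sketch of the naming surface after this commit (hypothetical usage against an already-constructed parser model, not code from the diff):

    # The hidden (formerly "lower") layer's parameters live on the
    # parser model itself, under the renamed keys:
    hidden_W = model.get_param("hidden_W")      # was "lower_W"
    hidden_b = model.get_param("hidden_b")      # was "lower_b"
    hidden_pad = model.get_param("hidden_pad")  # was "lower_pad"
    # The output (formerly "upper") layer is a separate thinc Linear,
    # reachable through a ref:
    output = model.get_ref("output")            # was model.get_ref("upper")
    output_W = output.get_param("W")            # key on the sublayer is unchanged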


@@ -45,28 +45,28 @@ def TransitionModel(
     tok2vec_projected = chain(tok2vec, list2array(), Linear(hidden_width, t2v_width))  # type: ignore
     tok2vec_projected.set_dim("nO", hidden_width)
-    # FIXME: we use `upper` as a container for the upper layer's
+    # FIXME: we use `output` as a container for the output layer's
     # weights and biases. Thinc optimizers cannot handle resizing
     # of parameters. So, when the parser model is resized, we
-    # construct a new `upper` layer, which has a different key in
+    # construct a new `output` layer, which has a different key in
     # the optimizer. Once the optimizer supports parameter resizing,
-    # we can replace the `upper` layer by `upper_W` and `upper_b`
+    # we can replace the `output` layer by `output_W` and `output_b`
     # parameters in this model.
-    upper = Linear(nO=None, nI=hidden_width, init_W=zero_init)
+    output = Linear(nO=None, nI=hidden_width, init_W=zero_init)
     return Model(
         name="parser_model",
         forward=forward,
         init=init,
-        layers=[tok2vec_projected, upper],
+        layers=[tok2vec_projected, output],
         refs={
             "tok2vec": tok2vec_projected,
-            "upper": upper,
+            "output": output,
         },
         params={
-            "lower_W": None,  # Floats2d W for the hidden layer
-            "lower_b": None,  # Floats1d bias for the hidden layer
-            "lower_pad": None,  # Floats1d padding for the hidden layer
+            "hidden_W": None,  # Floats2d W for the hidden layer
+            "hidden_b": None,  # Floats1d bias for the hidden layer
+            "hidden_pad": None,  # Floats1d padding for the hidden layer
         },
         dims={
             "nO": None,  # Output size
@@ -86,28 +86,28 @@ def TransitionModel(
 def resize_output(model: Model, new_nO: int) -> Model:
     old_nO = model.maybe_get_dim("nO")
-    upper = model.get_ref("upper")
+    output = model.get_ref("output")
     if old_nO is None:
         model.set_dim("nO", new_nO)
-        upper.set_dim("nO", new_nO)
-        upper.initialize()
+        output.set_dim("nO", new_nO)
+        output.initialize()
         return model
     elif new_nO <= old_nO:
         return model
-    elif upper.has_param("W"):
+    elif output.has_param("W"):
         nH = model.get_dim("nH")
-        new_upper = Linear(nO=new_nO, nI=nH, init_W=zero_init)
-        new_upper.initialize()
-        new_W = new_upper.get_param("W")
-        new_b = new_upper.get_param("b")
-        old_W = upper.get_param("W")
-        old_b = upper.get_param("b")
+        new_output = Linear(nO=new_nO, nI=nH, init_W=zero_init)
+        new_output.initialize()
+        new_W = new_output.get_param("W")
+        new_b = new_output.get_param("b")
+        old_W = output.get_param("W")
+        old_b = output.get_param("b")
         new_W[:old_nO] = old_W  # type: ignore
         new_b[:old_nO] = old_b  # type: ignore
         for i in range(old_nO, new_nO):
             model.attrs["unseen_classes"].add(i)
-        model.layers[-1] = new_upper
-        model.set_ref("upper", new_upper)
+        model.layers[-1] = new_output
+        model.set_ref("output", new_output)
         # TODO: Avoid this private intrusion
         model._dims["nO"] = new_nO
         return model
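The FIXME in the first hunk explains the shape of resize_output: Thinc optimizers key their per-parameter state by layer, so growing "W" in place would clash with state recorded for the old shape. The code therefore builds a fresh zero-initialized Linear and copies the old weights over. Condensed (illustrative only, reusing the names from this hunk):

    # Copy-and-swap resize: fresh layer, old rows copied in, new rows
    # zero-initialized and tracked as unseen classes until trained.
    new_output = Linear(nO=new_nO, nI=nH, init_W=zero_init)
    new_output.initialize()
    new_output.get_param("W")[:old_nO] = output.get_param("W")
    new_output.get_param("b")[:old_nO] = output.get_param("b")
    model.set_ref("output", new_output)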
@@ -141,10 +141,10 @@ def init(
     # Wl = zero_init(ops, Wl.shape)
     Wl = glorot_uniform_init(ops, Wl.shape)
     padl = uniform_init(ops, padl.shape)  # type: ignore
-    # TODO: Experiment with whether better to initialize upper_W
-    model.set_param("lower_W", Wl)
-    model.set_param("lower_b", bl)
-    model.set_param("lower_pad", padl)
+    # TODO: Experiment with whether better to initialize output_W
+    model.set_param("hidden_W", Wl)
+    model.set_param("hidden_b", bl)
+    model.set_param("hidden_pad", padl)
     # model = _lsuv_init(model)
     return model
@@ -160,12 +160,12 @@ def forward(model, docs_moves: InT, is_train: bool):
     docs, moves, actions = docs_moves
     beam_width = model.attrs["beam_width"]
-    lower_pad = model.get_param("lower_pad")
+    hidden_pad = model.get_param("hidden_pad")
     tok2vec = model.get_ref("tok2vec")
     states = moves.init_batch(docs)
     tokvecs, backprop_tok2vec = tok2vec(docs, is_train)
-    tokvecs = model.ops.xp.vstack((tokvecs, lower_pad))
+    tokvecs = model.ops.xp.vstack((tokvecs, hidden_pad))
     feats, backprop_feats = _forward_precomputable_affine(model, tokvecs, is_train)
     seen_mask = _get_seen_mask(model)
@@ -183,23 +183,23 @@ def _forward_greedy_cpu(model: Model, TransitionSystem moves, states: List[State
     for state in states:
         if not state.is_final():
             c_states.push_back(state.c)
-    weights = get_c_weights(model, <float*>feats.data, seen_mask)
+    weights = _get_c_weights(model, <float*>feats.data, seen_mask)
     # Precomputed features have rows for each token, plus one for padding.
     cdef int n_tokens = feats.shape[0] - 1
-    sizes = get_c_sizes(model, c_states.size(), n_tokens)
+    sizes = _get_c_sizes(model, c_states.size(), n_tokens)
     cdef CBlas cblas = model.ops.cblas()
-    scores = _parseC(cblas, moves, &c_states[0], weights, sizes, actions=actions)
+    scores = _parse_batch(cblas, moves, &c_states[0], weights, sizes, actions=actions)
     def backprop(dY):
         raise ValueError(Errors.E1038)
     return (states, scores), backprop

-cdef list _parseC(CBlas cblas, TransitionSystem moves, StateC** states,
-                  WeightsC weights, SizesC sizes, actions: Optional[List[Ints1d]]=None):
+cdef list _parse_batch(CBlas cblas, TransitionSystem moves, StateC** states,
+                       WeightsC weights, SizesC sizes, actions: Optional[List[Ints1d]]=None):
     cdef int i, j
     cdef vector[StateC *] unfinished
-    cdef ActivationsC activations = alloc_activations(sizes)
+    cdef ActivationsC activations = _alloc_activations(sizes)
     cdef np.ndarray step_scores
     cdef np.ndarray step_actions
@@ -208,7 +208,7 @@ cdef list _parseC(CBlas cblas, TransitionSystem moves, StateC** states,
         step_scores = numpy.empty((sizes.states, sizes.classes), dtype="f")
         step_actions = actions[0] if actions is not None else None
         with nogil:
-            predict_states(cblas, &activations, <float*>step_scores.data, states, &weights, sizes)
+            _predict_states(cblas, &activations, <float*>step_scores.data, states, &weights, sizes)
             if actions is None:
                 # Validate actions, argmax, take action.
                 c_transition_batch(moves, states, <const float*>step_scores.data, sizes.classes,
@@ -224,7 +224,7 @@ cdef list _parseC(CBlas cblas, TransitionSystem moves, StateC** states,
         scores.append(step_scores)
         unfinished.clear()
         actions = actions[1:] if actions is not None else None
-    free_activations(&activations)
+    _free_activations(&activations)
     return scores
@@ -232,8 +232,8 @@ cdef list _parseC(CBlas cblas, TransitionSystem moves, StateC** states,
 def _forward_fallback(model: Model, moves: TransitionSystem, states: List[StateClass], tokvecs, backprop_tok2vec, feats, backprop_feats, seen_mask, is_train: bool,
                       actions: Optional[List[Ints1d]]=None):
     nF = model.get_dim("nF")
-    upper = model.get_ref("upper")
-    lower_b = model.get_param("lower_b")
+    output = model.get_ref("output")
+    hidden_b = model.get_param("hidden_b")
     nH = model.get_dim("nH")
     nP = model.get_dim("nP")
@@ -260,13 +260,13 @@ def _forward_fallback(model: Model, moves: TransitionSystem, states: List[StateC
         # Sum the state features, add the bias and apply the activation (maxout)
         # to create the state vectors.
         preacts2f = feats[ids, arange].sum(axis=1)  # type: ignore
-        preacts2f += lower_b
+        preacts2f += hidden_b
         preacts = ops.reshape3f(preacts2f, preacts2f.shape[0], nH, nP)
         assert preacts.shape[0] == len(batch.get_unfinished_states()), preacts.shape
         statevecs, which = ops.maxout(preacts)
-        # We don't use upper's backprop, since we want to backprop for
+        # We don't use output's backprop, since we want to backprop for
         # all states at once, rather than a single state.
-        scores = upper.predict(statevecs)
+        scores = output.predict(statevecs)
         scores[:, seen_mask] = ops.xp.nanmin(scores)
         # Transition the states, filtering out any that are finished.
         cpu_scores = ops.to_numpy(scores)
@@ -295,23 +295,23 @@ def _forward_fallback(model: Model, moves: TransitionSystem, states: List[StateC
             if (d_scores[:, clas] < 0).any():
                 model.attrs["unseen_classes"].remove(clas)
         d_scores *= seen_mask == False
-        # Calculate the gradients for the parameters of the upper layer.
+        # Calculate the gradients for the parameters of the output layer.
         # The weight gemm is (nS, nO) @ (nS, nH).T
-        upper.inc_grad("b", d_scores.sum(axis=0))
-        upper.inc_grad("W", ops.gemm(d_scores, statevecs, trans1=True))
-        # Now calculate d_statevecs, by backproping through the upper linear layer.
+        output.inc_grad("b", d_scores.sum(axis=0))
+        output.inc_grad("W", ops.gemm(d_scores, statevecs, trans1=True))
+        # Now calculate d_statevecs, by backproping through the output linear layer.
         # This gemm is (nS, nO) @ (nO, nH)
-        upper_W = upper.get_param("W")
-        d_statevecs = ops.gemm(d_scores, upper_W)
+        output_W = output.get_param("W")
+        d_statevecs = ops.gemm(d_scores, output_W)
         # Backprop through the maxout activation
         d_preacts = ops.backprop_maxout(d_statevecs, which, nP)
         d_preacts2f = ops.reshape2f(d_preacts, d_preacts.shape[0], nH * nP)
-        model.inc_grad("lower_b", d_preacts2f.sum(axis=0))
+        model.inc_grad("hidden_b", d_preacts2f.sum(axis=0))
         # We don't need to backprop the summation, because we pass back the IDs instead
         d_state_features = backprop_feats((d_preacts2f, ids))
         d_tokvecs = ops.alloc2f(tokvecs.shape[0], tokvecs.shape[1])
         ops.scatter_add(d_tokvecs, ids, d_state_features)
-        model.inc_grad("lower_pad", d_tokvecs[-1])
+        model.inc_grad("hidden_pad", d_tokvecs[-1])
         return (backprop_tok2vec(d_tokvecs[:-1]), None)

     return (list(batch), all_scores), backprop_parser
@@ -323,10 +323,10 @@ def _forward_reference(
     """Slow reference implementation, without the precomputation"""
     nF = model.get_dim("nF")
     tok2vec = model.get_ref("tok2vec")
-    upper = model.get_ref("upper")
-    lower_pad = model.get_param("lower_pad")
-    lower_W = model.get_param("lower_W")
-    lower_b = model.get_param("lower_b")
+    output = model.get_ref("output")
+    hidden_pad = model.get_param("hidden_pad")
+    hidden_W = model.get_param("hidden_W")
+    hidden_b = model.get_param("hidden_b")
     nH = model.get_dim("nH")
     nP = model.get_dim("nP")
     nO = model.get_dim("nO")
@@ -336,7 +336,7 @@ def _forward_reference(
     docs, moves = docs_moves
     states = moves.init_batch(docs)
     tokvecs, backprop_tok2vec = tok2vec(docs, is_train)
-    tokvecs = model.ops.xp.vstack((tokvecs, lower_pad))
+    tokvecs = model.ops.xp.vstack((tokvecs, hidden_pad))
     all_ids = []
     all_which = []
     all_statevecs = []
@@ -353,13 +353,13 @@ def _forward_reference(
         # to create the state vectors.
         tokfeats3f = tokvecs[ids]
         tokfeats = model.ops.reshape2f(tokfeats3f, tokfeats3f.shape[0], -1)
-        preacts2f = model.ops.gemm(tokfeats, lower_W, trans2=True)
-        preacts2f += lower_b
+        preacts2f = model.ops.gemm(tokfeats, hidden_W, trans2=True)
+        preacts2f += hidden_b
         preacts = model.ops.reshape3f(preacts2f, preacts2f.shape[0], nH, nP)
         statevecs, which = ops.maxout(preacts)
-        # We don't use upper's backprop, since we want to backprop for
+        # We don't use output's backprop, since we want to backprop for
         # all states at once, rather than a single state.
-        scores = upper.predict(statevecs)
+        scores = output.predict(statevecs)
         scores[:, seen_mask] = model.ops.xp.nanmin(scores)
         # Transition the states, filtering out any that are finished.
         next_states = moves.transition_states(next_states, scores)
@@ -390,28 +390,28 @@ def _forward_reference(
         d_scores *= seen_mask == False
         assert statevecs.shape == (nS, nH), statevecs.shape
         assert d_scores.shape == (nS, nO), d_scores.shape
-        # Calculate the gradients for the parameters of the upper layer.
+        # Calculate the gradients for the parameters of the output layer.
         # The weight gemm is (nS, nO) @ (nS, nH).T
-        upper.inc_grad("b", d_scores.sum(axis=0))
-        upper.inc_grad("W", model.ops.gemm(d_scores, statevecs, trans1=True))
-        # Now calculate d_statevecs, by backproping through the upper linear layer.
+        output.inc_grad("b", d_scores.sum(axis=0))
+        output.inc_grad("W", model.ops.gemm(d_scores, statevecs, trans1=True))
+        # Now calculate d_statevecs, by backproping through the output linear layer.
         # This gemm is (nS, nO) @ (nO, nH)
-        upper_W = upper.get_param("W")
-        d_statevecs = model.ops.gemm(d_scores, upper_W)
+        output_W = output.get_param("W")
+        d_statevecs = model.ops.gemm(d_scores, output_W)
         # Backprop through the maxout activation
         d_preacts = model.ops.backprop_maxout(d_statevecs, which, nP)
         d_preacts2f = model.ops.reshape2f(d_preacts, d_preacts.shape[0], nH * nP)
-        # Now increment the gradients for the lower layer.
+        # Now increment the gradients for the hidden layer.
         # The gemm here is (nS, nH*nP) @ (nS, nF*nI)
-        model.inc_grad("lower_b", d_preacts2f.sum(axis=0))
-        model.inc_grad("lower_W", model.ops.gemm(d_preacts2f, tokfeats, trans1=True))
+        model.inc_grad("hidden_b", d_preacts2f.sum(axis=0))
+        model.inc_grad("hidden_W", model.ops.gemm(d_preacts2f, tokfeats, trans1=True))
         # Caclulate d_tokfeats
         # The gemm here is (nS, nH*nP) @ (nH*nP, nF*nI)
-        d_tokfeats = model.ops.gemm(d_preacts2f, lower_W)
+        d_tokfeats = model.ops.gemm(d_preacts2f, hidden_W)
         # Get the gradients of the tokvecs and the padding
         d_tokfeats3f = model.ops.reshape3f(d_tokfeats, nS, nF, nI)
         model.ops.scatter_add(d_tokvecs, ids, d_tokfeats3f)
-        model.inc_grad("lower_pad", d_tokvecs[-1])
+        model.inc_grad("hidden_pad", d_tokvecs[-1])
         return (backprop_tok2vec(d_tokvecs[:-1]), None)

     return (states, all_scores), backprop_parser
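The gemm comments in the two backprop blocks above are the standard linear-layer gradients. With statevecs of shape (nS, nH), output weights W of shape (nO, nH), and scores = statevecs @ W.T + b, the weight gradient is d_scores.T @ statevecs and the input gradient is d_scores @ W. A plain-numpy shape check (illustrative, not code from the diff):

    import numpy

    nS, nH, nO = 4, 3, 5
    statevecs = numpy.zeros((nS, nH), dtype="f")
    W = numpy.zeros((nO, nH), dtype="f")
    d_scores = numpy.zeros((nS, nO), dtype="f")
    dW = d_scores.T @ statevecs    # gemm(d_scores, statevecs, trans1=True)
    d_statevecs = d_scores @ W     # gemm(d_scores, output_W)
    assert dW.shape == (nO, nH)
    assert d_statevecs.shape == (nS, nH)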
@@ -425,7 +425,7 @@ def _get_seen_mask(model: Model) -> numpy.array[bool, 1]:
 def _forward_precomputable_affine(model, X: Floats2d, is_train: bool):
-    W: Floats2d = model.get_param("lower_W")
+    W: Floats2d = model.get_param("hidden_W")
     nF = model.get_dim("nF")
     nH = model.get_dim("nH")
     nP = model.get_dim("nP")
@@ -456,7 +456,7 @@ def _forward_precomputable_affine(model, X: Floats2d, is_train: bool):
         dXf = model.ops.gemm(dY, W)
         Xf = X[ids].reshape((ids.shape[0], -1))
         dW = model.ops.gemm(dY, Xf, trans1=True)
-        model.inc_grad("lower_W", dW)
+        model.inc_grad("hidden_W", dW)
         return model.ops.reshape3f(dXf, dXf.shape[0], nF, nI)

     return Yf, backward
@@ -482,7 +482,7 @@ def _lsuv_init(model: Model):
     we set the maxout weights to values that empirically result in
     whitened outputs given whitened inputs.
     """
-    W = model.maybe_get_param("lower_W")
+    W = model.maybe_get_param("hidden_W")
     if W is not None and W.any():
         return
@@ -523,66 +523,66 @@ def _lsuv_init(model: Model):
     tol_var = 0.01
     tol_mean = 0.01
     t_max = 10
-    W = cast(Floats4d, model.get_param("lower_W").copy())
-    b = cast(Floats2d, model.get_param("lower_b").copy())
+    W = cast(Floats4d, model.get_param("hidden_W").copy())
+    b = cast(Floats2d, model.get_param("hidden_b").copy())
     for t_i in range(t_max):
         acts1 = predict(ids, tokvecs)
         var = model.ops.xp.var(acts1)
         mean = model.ops.xp.mean(acts1)
         if abs(var - 1.0) >= tol_var:
             W /= model.ops.xp.sqrt(var)
-            model.set_param("lower_W", W)
+            model.set_param("hidden_W", W)
         elif abs(mean) >= tol_mean:
             b -= mean
-            model.set_param("lower_b", b)
+            model.set_param("hidden_b", b)
         else:
             break
     return model

-cdef WeightsC get_c_weights(model, const float* feats, np.ndarray[np.npy_bool, ndim=1] seen_mask) except *:
-    upper = model.get_ref("upper")
-    cdef np.ndarray lower_b = model.get_param("lower_b")
-    cdef np.ndarray upper_W = upper.get_param("W")
-    cdef np.ndarray upper_b = upper.get_param("b")
-    cdef WeightsC output
-    output.feat_weights = feats
-    output.feat_bias = <const float*>lower_b.data
-    output.hidden_weights = <const float *> upper_W.data
-    output.hidden_bias = <const float *> upper_b.data
-    output.seen_mask = <const int8_t*> seen_mask.data
-    return output
+cdef WeightsC _get_c_weights(model, const float* feats, np.ndarray[np.npy_bool, ndim=1] seen_mask) except *:
+    output = model.get_ref("output")
+    cdef np.ndarray hidden_b = model.get_param("hidden_b")
+    cdef np.ndarray output_W = output.get_param("W")
+    cdef np.ndarray output_b = output.get_param("b")
+    cdef WeightsC weights
+    weights.feat_weights = feats
+    weights.feat_bias = <const float*>hidden_b.data
+    weights.hidden_weights = <const float *> output_W.data
+    weights.hidden_bias = <const float *> output_b.data
+    weights.seen_mask = <const int8_t*> seen_mask.data
+    return weights

-cdef SizesC get_c_sizes(model, int batch_size, int tokens) except *:
-    cdef SizesC output
-    output.states = batch_size
-    output.classes = model.get_dim("nO")
-    output.hiddens = model.get_dim("nH")
-    output.pieces = model.get_dim("nP")
-    output.feats = model.get_dim("nF")
-    output.embed_width = model.get_dim("nI")
-    output.tokens = tokens
-    return output
+cdef SizesC _get_c_sizes(model, int batch_size, int tokens) except *:
+    cdef SizesC sizes
+    sizes.states = batch_size
+    sizes.classes = model.get_dim("nO")
+    sizes.hiddens = model.get_dim("nH")
+    sizes.pieces = model.get_dim("nP")
+    sizes.feats = model.get_dim("nF")
+    sizes.embed_width = model.get_dim("nI")
+    sizes.tokens = tokens
+    return sizes

-cdef ActivationsC alloc_activations(SizesC n) nogil:
+cdef ActivationsC _alloc_activations(SizesC n) nogil:
     cdef ActivationsC A
     memset(&A, 0, sizeof(A))
-    resize_activations(&A, n)
+    _resize_activations(&A, n)
     return A

-cdef void free_activations(const ActivationsC* A) nogil:
+cdef void _free_activations(const ActivationsC* A) nogil:
     free(A.token_ids)
     free(A.unmaxed)
     free(A.hiddens)
     free(A.is_valid)

-cdef void resize_activations(ActivationsC* A, SizesC n) nogil:
+cdef void _resize_activations(ActivationsC* A, SizesC n) nogil:
     if n.states <= A._max_size:
         A._curr_size = n.states
         return
@@ -605,12 +605,12 @@ cdef void resize_activations(ActivationsC* A, SizesC n) nogil:
     A._curr_size = n.states

-cdef void predict_states(CBlas cblas, ActivationsC* A, float* scores, StateC** states, const WeightsC* W, SizesC n) nogil:
-    resize_activations(A, n)
+cdef void _predict_states(CBlas cblas, ActivationsC* A, float* scores, StateC** states, const WeightsC* W, SizesC n) nogil:
+    _resize_activations(A, n)
     for i in range(n.states):
         states[i].set_context_tokens(&A.token_ids[i*n.feats], n.feats)
     memset(A.unmaxed, 0, n.states * n.hiddens * n.pieces * sizeof(float))
-    sum_state_features(cblas, A.unmaxed, W.feat_weights, A.token_ids, n)
+    _sum_state_features(cblas, A.unmaxed, W.feat_weights, A.token_ids, n)
     for i in range(n.states):
         VecVec.add_i(&A.unmaxed[i*n.hiddens*n.pieces],
                      W.feat_bias, 1., n.hiddens * n.pieces)
@@ -641,7 +641,7 @@ cdef void predict_states(CBlas cblas, ActivationsC* A, float* scores, StateC** s
             scores[i*n.classes+j] = min_

-cdef void sum_state_features(CBlas cblas, float* output,
-                             const float* cached, const int* token_ids, SizesC n) nogil:
+cdef void _sum_state_features(CBlas cblas, float* output,
+                              const float* cached, const int* token_ids, SizesC n) nogil:
     cdef int idx, b, f, i
     cdef const float* feature
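Taken together, the helper renames in this file all follow the underscore convention announced in the commit message:

* get_c_weights -> _get_c_weights
* get_c_sizes -> _get_c_sizes
* alloc_activations -> _alloc_activations
* free_activations -> _free_activations
* resize_activations -> _resize_activations
* predict_states -> _predict_states
* sum_state_features -> _sum_state_features
* _parseC -> _parse_batch (already private; renamed for descriptiveness)

The second changed file, below, updates the serialization tests to the renamed parameters and ref.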


@@ -263,12 +263,12 @@ def test_serialize_custom_nlp():
         nlp2 = spacy.load(d)
         model = nlp2.get_pipe("parser").model
         assert model.get_ref("tok2vec") is not None
-        assert model.has_param("lower_W")
-        assert model.has_param("lower_b")
-        upper = model.get_ref("upper")
-        assert upper is not None
-        assert upper.has_param("W")
-        assert upper.has_param("b")
+        assert model.has_param("hidden_W")
+        assert model.has_param("hidden_b")
+        output = model.get_ref("output")
+        assert output is not None
+        assert output.has_param("W")
+        assert output.has_param("b")

 @pytest.mark.parametrize("parser_config_string", [parser_config_string_upper])
@@ -285,12 +285,12 @@ def test_serialize_parser(parser_config_string):
         nlp2 = spacy.load(d)
         model = nlp2.get_pipe("parser").model
         assert model.get_ref("tok2vec") is not None
-        assert model.has_param("lower_W")
-        assert model.has_param("lower_b")
-        upper = model.get_ref("upper")
-        assert upper is not None
-        assert upper.has_param("b")
-        assert upper.has_param("W")
+        assert model.has_param("hidden_W")
+        assert model.has_param("hidden_b")
+        output = model.get_ref("output")
+        assert output is not None
+        assert output.has_param("b")
+        assert output.has_param("W")

 def test_config_nlp_roundtrip():