Rename some identifiers in the parser refactor (#10935)

* Rename _parseC to _parse_batch

* tb_framework: prefix many auxiliary functions with underscore

This makes it explicit that they are intended to be private.

* Rename `lower` to `hidden`, `upper` to `output`
Daniël de Kok authored on 2022-06-09 13:45:30 +02:00 (committed by GitHub)
commit bc36c71982, parent 63e90dd6a1
2 changed files with 117 additions and 117 deletions
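
For orientation, here is a minimal usage sketch (not part of the diff) of the new names. It mirrors the updated serialization tests in the second changed file and assumes `nlp` is a pipeline whose "parser" component was built with the refactored TransitionModel from this branch:

# Hypothetical illustration only; assumes `nlp` is a pipeline built with the
# refactored parser model from this branch.
model = nlp.get_pipe("parser").model
# Hidden-layer parameters (formerly "lower_*") live on the outer model.
assert model.has_param("hidden_W")
assert model.has_param("hidden_b")
assert model.has_param("hidden_pad")
# The output layer (formerly "upper") is a separate Linear kept as a ref.
output = model.get_ref("output")
assert output.has_param("W")
assert output.has_param("b")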

Changed file 1 of 2 (tb_framework)

@@ -45,28 +45,28 @@ def TransitionModel(
tok2vec_projected = chain(tok2vec, list2array(), Linear(hidden_width, t2v_width)) # type: ignore
tok2vec_projected.set_dim("nO", hidden_width)
# FIXME: we use `upper` as a container for the upper layer's
# FIXME: we use `output` as a container for the output layer's
# weights and biases. Thinc optimizers cannot handle resizing
# of parameters. So, when the parser model is resized, we
# construct a new `upper` layer, which has a different key in
# construct a new `output` layer, which has a different key in
# the optimizer. Once the optimizer supports parameter resizing,
# we can replace the `upper` layer by `upper_W` and `upper_b`
# we can replace the `output` layer by `output_W` and `output_b`
# parameters in this model.
upper = Linear(nO=None, nI=hidden_width, init_W=zero_init)
output = Linear(nO=None, nI=hidden_width, init_W=zero_init)
return Model(
name="parser_model",
forward=forward,
init=init,
layers=[tok2vec_projected, upper],
layers=[tok2vec_projected, output],
refs={
"tok2vec": tok2vec_projected,
"upper": upper,
"output": output,
},
params={
"lower_W": None, # Floats2d W for the hidden layer
"lower_b": None, # Floats1d bias for the hidden layer
"lower_pad": None, # Floats1d padding for the hidden layer
"hidden_W": None, # Floats2d W for the hidden layer
"hidden_b": None, # Floats1d bias for the hidden layer
"hidden_pad": None, # Floats1d padding for the hidden layer
},
dims={
"nO": None, # Output size
@@ -86,28 +86,28 @@ def TransitionModel(
def resize_output(model: Model, new_nO: int) -> Model:
old_nO = model.maybe_get_dim("nO")
upper = model.get_ref("upper")
output = model.get_ref("output")
if old_nO is None:
model.set_dim("nO", new_nO)
upper.set_dim("nO", new_nO)
upper.initialize()
output.set_dim("nO", new_nO)
output.initialize()
return model
elif new_nO <= old_nO:
return model
elif upper.has_param("W"):
elif output.has_param("W"):
nH = model.get_dim("nH")
new_upper = Linear(nO=new_nO, nI=nH, init_W=zero_init)
new_upper.initialize()
new_W = new_upper.get_param("W")
new_b = new_upper.get_param("b")
old_W = upper.get_param("W")
old_b = upper.get_param("b")
new_output = Linear(nO=new_nO, nI=nH, init_W=zero_init)
new_output.initialize()
new_W = new_output.get_param("W")
new_b = new_output.get_param("b")
old_W = output.get_param("W")
old_b = output.get_param("b")
new_W[:old_nO] = old_W # type: ignore
new_b[:old_nO] = old_b # type: ignore
for i in range(old_nO, new_nO):
model.attrs["unseen_classes"].add(i)
model.layers[-1] = new_upper
model.set_ref("upper", new_upper)
model.layers[-1] = new_output
model.set_ref("output", new_output)
# TODO: Avoid this private intrusion
model._dims["nO"] = new_nO
return model
@@ -141,10 +141,10 @@ def init(
# Wl = zero_init(ops, Wl.shape)
Wl = glorot_uniform_init(ops, Wl.shape)
padl = uniform_init(ops, padl.shape) # type: ignore
# TODO: Experiment with whether better to initialize upper_W
model.set_param("lower_W", Wl)
model.set_param("lower_b", bl)
model.set_param("lower_pad", padl)
# TODO: Experiment with whether better to initialize output_W
model.set_param("hidden_W", Wl)
model.set_param("hidden_b", bl)
model.set_param("hidden_pad", padl)
# model = _lsuv_init(model)
return model
@@ -160,12 +160,12 @@ def forward(model, docs_moves: InT, is_train: bool):
docs, moves, actions = docs_moves
beam_width = model.attrs["beam_width"]
lower_pad = model.get_param("lower_pad")
hidden_pad = model.get_param("hidden_pad")
tok2vec = model.get_ref("tok2vec")
states = moves.init_batch(docs)
tokvecs, backprop_tok2vec = tok2vec(docs, is_train)
tokvecs = model.ops.xp.vstack((tokvecs, lower_pad))
tokvecs = model.ops.xp.vstack((tokvecs, hidden_pad))
feats, backprop_feats = _forward_precomputable_affine(model, tokvecs, is_train)
seen_mask = _get_seen_mask(model)
@@ -183,23 +183,23 @@ def _forward_greedy_cpu(model: Model, TransitionSystem moves, states: List[State
for state in states:
if not state.is_final():
c_states.push_back(state.c)
weights = get_c_weights(model, <float*>feats.data, seen_mask)
weights = _get_c_weights(model, <float*>feats.data, seen_mask)
# Precomputed features have rows for each token, plus one for padding.
cdef int n_tokens = feats.shape[0] - 1
sizes = get_c_sizes(model, c_states.size(), n_tokens)
sizes = _get_c_sizes(model, c_states.size(), n_tokens)
cdef CBlas cblas = model.ops.cblas()
scores = _parseC(cblas, moves, &c_states[0], weights, sizes, actions=actions)
scores = _parse_batch(cblas, moves, &c_states[0], weights, sizes, actions=actions)
def backprop(dY):
raise ValueError(Errors.E1038)
return (states, scores), backprop
cdef list _parseC(CBlas cblas, TransitionSystem moves, StateC** states,
WeightsC weights, SizesC sizes, actions: Optional[List[Ints1d]]=None):
cdef list _parse_batch(CBlas cblas, TransitionSystem moves, StateC** states,
WeightsC weights, SizesC sizes, actions: Optional[List[Ints1d]]=None):
cdef int i, j
cdef vector[StateC *] unfinished
cdef ActivationsC activations = alloc_activations(sizes)
cdef ActivationsC activations = _alloc_activations(sizes)
cdef np.ndarray step_scores
cdef np.ndarray step_actions
@@ -208,7 +208,7 @@ cdef list _parseC(CBlas cblas, TransitionSystem moves, StateC** states,
step_scores = numpy.empty((sizes.states, sizes.classes), dtype="f")
step_actions = actions[0] if actions is not None else None
with nogil:
predict_states(cblas, &activations, <float*>step_scores.data, states, &weights, sizes)
_predict_states(cblas, &activations, <float*>step_scores.data, states, &weights, sizes)
if actions is None:
# Validate actions, argmax, take action.
c_transition_batch(moves, states, <const float*>step_scores.data, sizes.classes,
@@ -224,7 +224,7 @@ cdef list _parseC(CBlas cblas, TransitionSystem moves, StateC** states,
scores.append(step_scores)
unfinished.clear()
actions = actions[1:] if actions is not None else None
free_activations(&activations)
_free_activations(&activations)
return scores
@@ -232,8 +232,8 @@ cdef list _parseC(CBlas cblas, TransitionSystem moves, StateC** states,
def _forward_fallback(model: Model, moves: TransitionSystem, states: List[StateClass], tokvecs, backprop_tok2vec, feats, backprop_feats, seen_mask, is_train: bool,
actions: Optional[List[Ints1d]]=None):
nF = model.get_dim("nF")
upper = model.get_ref("upper")
lower_b = model.get_param("lower_b")
output = model.get_ref("output")
hidden_b = model.get_param("hidden_b")
nH = model.get_dim("nH")
nP = model.get_dim("nP")
@@ -260,13 +260,13 @@ def _forward_fallback(model: Model, moves: TransitionSystem, states: List[StateC
# Sum the state features, add the bias and apply the activation (maxout)
# to create the state vectors.
preacts2f = feats[ids, arange].sum(axis=1) # type: ignore
preacts2f += lower_b
preacts2f += hidden_b
preacts = ops.reshape3f(preacts2f, preacts2f.shape[0], nH, nP)
assert preacts.shape[0] == len(batch.get_unfinished_states()), preacts.shape
statevecs, which = ops.maxout(preacts)
# We don't use upper's backprop, since we want to backprop for
# We don't use output's backprop, since we want to backprop for
# all states at once, rather than a single state.
scores = upper.predict(statevecs)
scores = output.predict(statevecs)
scores[:, seen_mask] = ops.xp.nanmin(scores)
# Transition the states, filtering out any that are finished.
cpu_scores = ops.to_numpy(scores)
@@ -295,23 +295,23 @@ def _forward_fallback(model: Model, moves: TransitionSystem, states: List[StateC
if (d_scores[:, clas] < 0).any():
model.attrs["unseen_classes"].remove(clas)
d_scores *= seen_mask == False
# Calculate the gradients for the parameters of the upper layer.
# Calculate the gradients for the parameters of the output layer.
# The weight gemm is (nS, nO) @ (nS, nH).T
upper.inc_grad("b", d_scores.sum(axis=0))
upper.inc_grad("W", ops.gemm(d_scores, statevecs, trans1=True))
# Now calculate d_statevecs, by backproping through the upper linear layer.
output.inc_grad("b", d_scores.sum(axis=0))
output.inc_grad("W", ops.gemm(d_scores, statevecs, trans1=True))
# Now calculate d_statevecs, by backproping through the output linear layer.
# This gemm is (nS, nO) @ (nO, nH)
upper_W = upper.get_param("W")
d_statevecs = ops.gemm(d_scores, upper_W)
output_W = output.get_param("W")
d_statevecs = ops.gemm(d_scores, output_W)
# Backprop through the maxout activation
d_preacts = ops.backprop_maxout(d_statevecs, which, nP)
d_preacts2f = ops.reshape2f(d_preacts, d_preacts.shape[0], nH * nP)
model.inc_grad("lower_b", d_preacts2f.sum(axis=0))
model.inc_grad("hidden_b", d_preacts2f.sum(axis=0))
# We don't need to backprop the summation, because we pass back the IDs instead
d_state_features = backprop_feats((d_preacts2f, ids))
d_tokvecs = ops.alloc2f(tokvecs.shape[0], tokvecs.shape[1])
ops.scatter_add(d_tokvecs, ids, d_state_features)
model.inc_grad("lower_pad", d_tokvecs[-1])
model.inc_grad("hidden_pad", d_tokvecs[-1])
return (backprop_tok2vec(d_tokvecs[:-1]), None)
return (list(batch), all_scores), backprop_parser
@@ -323,10 +323,10 @@ def _forward_reference(
"""Slow reference implementation, without the precomputation"""
nF = model.get_dim("nF")
tok2vec = model.get_ref("tok2vec")
upper = model.get_ref("upper")
lower_pad = model.get_param("lower_pad")
lower_W = model.get_param("lower_W")
lower_b = model.get_param("lower_b")
output = model.get_ref("output")
hidden_pad = model.get_param("hidden_pad")
hidden_W = model.get_param("hidden_W")
hidden_b = model.get_param("hidden_b")
nH = model.get_dim("nH")
nP = model.get_dim("nP")
nO = model.get_dim("nO")
@@ -336,7 +336,7 @@ def _forward_reference(
docs, moves = docs_moves
states = moves.init_batch(docs)
tokvecs, backprop_tok2vec = tok2vec(docs, is_train)
tokvecs = model.ops.xp.vstack((tokvecs, lower_pad))
tokvecs = model.ops.xp.vstack((tokvecs, hidden_pad))
all_ids = []
all_which = []
all_statevecs = []
@@ -353,13 +353,13 @@ def _forward_reference(
# to create the state vectors.
tokfeats3f = tokvecs[ids]
tokfeats = model.ops.reshape2f(tokfeats3f, tokfeats3f.shape[0], -1)
preacts2f = model.ops.gemm(tokfeats, lower_W, trans2=True)
preacts2f += lower_b
preacts2f = model.ops.gemm(tokfeats, hidden_W, trans2=True)
preacts2f += hidden_b
preacts = model.ops.reshape3f(preacts2f, preacts2f.shape[0], nH, nP)
statevecs, which = ops.maxout(preacts)
# We don't use upper's backprop, since we want to backprop for
# We don't use output's backprop, since we want to backprop for
# all states at once, rather than a single state.
scores = upper.predict(statevecs)
scores = output.predict(statevecs)
scores[:, seen_mask] = model.ops.xp.nanmin(scores)
# Transition the states, filtering out any that are finished.
next_states = moves.transition_states(next_states, scores)
@@ -390,28 +390,28 @@ def _forward_reference(
d_scores *= seen_mask == False
assert statevecs.shape == (nS, nH), statevecs.shape
assert d_scores.shape == (nS, nO), d_scores.shape
# Calculate the gradients for the parameters of the upper layer.
# Calculate the gradients for the parameters of the output layer.
# The weight gemm is (nS, nO) @ (nS, nH).T
upper.inc_grad("b", d_scores.sum(axis=0))
upper.inc_grad("W", model.ops.gemm(d_scores, statevecs, trans1=True))
# Now calculate d_statevecs, by backproping through the upper linear layer.
output.inc_grad("b", d_scores.sum(axis=0))
output.inc_grad("W", model.ops.gemm(d_scores, statevecs, trans1=True))
# Now calculate d_statevecs, by backproping through the output linear layer.
# This gemm is (nS, nO) @ (nO, nH)
upper_W = upper.get_param("W")
d_statevecs = model.ops.gemm(d_scores, upper_W)
output_W = output.get_param("W")
d_statevecs = model.ops.gemm(d_scores, output_W)
# Backprop through the maxout activation
d_preacts = model.ops.backprop_maxout(d_statevecs, which, nP)
d_preacts2f = model.ops.reshape2f(d_preacts, d_preacts.shape[0], nH * nP)
# Now increment the gradients for the lower layer.
# Now increment the gradients for the hidden layer.
# The gemm here is (nS, nH*nP) @ (nS, nF*nI)
model.inc_grad("lower_b", d_preacts2f.sum(axis=0))
model.inc_grad("lower_W", model.ops.gemm(d_preacts2f, tokfeats, trans1=True))
model.inc_grad("hidden_b", d_preacts2f.sum(axis=0))
model.inc_grad("hidden_W", model.ops.gemm(d_preacts2f, tokfeats, trans1=True))
# Calculate d_tokfeats
# The gemm here is (nS, nH*nP) @ (nH*nP, nF*nI)
d_tokfeats = model.ops.gemm(d_preacts2f, lower_W)
d_tokfeats = model.ops.gemm(d_preacts2f, hidden_W)
# Get the gradients of the tokvecs and the padding
d_tokfeats3f = model.ops.reshape3f(d_tokfeats, nS, nF, nI)
model.ops.scatter_add(d_tokvecs, ids, d_tokfeats3f)
model.inc_grad("lower_pad", d_tokvecs[-1])
model.inc_grad("hidden_pad", d_tokvecs[-1])
return (backprop_tok2vec(d_tokvecs[:-1]), None)
return (states, all_scores), backprop_parser
@@ -425,7 +425,7 @@ def _get_seen_mask(model: Model) -> numpy.array[bool, 1]:
def _forward_precomputable_affine(model, X: Floats2d, is_train: bool):
W: Floats2d = model.get_param("lower_W")
W: Floats2d = model.get_param("hidden_W")
nF = model.get_dim("nF")
nH = model.get_dim("nH")
nP = model.get_dim("nP")
@@ -456,7 +456,7 @@ def _forward_precomputable_affine(model, X: Floats2d, is_train: bool):
dXf = model.ops.gemm(dY, W)
Xf = X[ids].reshape((ids.shape[0], -1))
dW = model.ops.gemm(dY, Xf, trans1=True)
model.inc_grad("lower_W", dW)
model.inc_grad("hidden_W", dW)
return model.ops.reshape3f(dXf, dXf.shape[0], nF, nI)
return Yf, backward
@@ -482,7 +482,7 @@ def _lsuv_init(model: Model):
we set the maxout weights to values that empirically result in
whitened outputs given whitened inputs.
"""
W = model.maybe_get_param("lower_W")
W = model.maybe_get_param("hidden_W")
if W is not None and W.any():
return
@@ -523,66 +523,66 @@ def _lsuv_init(model):
tol_var = 0.01
tol_mean = 0.01
t_max = 10
W = cast(Floats4d, model.get_param("lower_W").copy())
b = cast(Floats2d, model.get_param("lower_b").copy())
W = cast(Floats4d, model.get_param("hidden_W").copy())
b = cast(Floats2d, model.get_param("hidden_b").copy())
for t_i in range(t_max):
acts1 = predict(ids, tokvecs)
var = model.ops.xp.var(acts1)
mean = model.ops.xp.mean(acts1)
if abs(var - 1.0) >= tol_var:
W /= model.ops.xp.sqrt(var)
model.set_param("lower_W", W)
model.set_param("hidden_W", W)
elif abs(mean) >= tol_mean:
b -= mean
model.set_param("lower_b", b)
model.set_param("hidden_b", b)
else:
break
return model
cdef WeightsC get_c_weights(model, const float* feats, np.ndarray[np.npy_bool, ndim=1] seen_mask) except *:
upper = model.get_ref("upper")
cdef np.ndarray lower_b = model.get_param("lower_b")
cdef np.ndarray upper_W = upper.get_param("W")
cdef np.ndarray upper_b = upper.get_param("b")
cdef WeightsC _get_c_weights(model, const float* feats, np.ndarray[np.npy_bool, ndim=1] seen_mask) except *:
output = model.get_ref("output")
cdef np.ndarray hidden_b = model.get_param("hidden_b")
cdef np.ndarray output_W = output.get_param("W")
cdef np.ndarray output_b = output.get_param("b")
cdef WeightsC output
output.feat_weights = feats
output.feat_bias = <const float*>lower_b.data
output.hidden_weights = <const float *> upper_W.data
output.hidden_bias = <const float *> upper_b.data
output.seen_mask = <const int8_t*> seen_mask.data
cdef WeightsC weights
weights.feat_weights = feats
weights.feat_bias = <const float*>hidden_b.data
weights.hidden_weights = <const float *> output_W.data
weights.hidden_bias = <const float *> output_b.data
weights.seen_mask = <const int8_t*> seen_mask.data
return output
return weights
cdef SizesC get_c_sizes(model, int batch_size, int tokens) except *:
cdef SizesC output
output.states = batch_size
output.classes = model.get_dim("nO")
output.hiddens = model.get_dim("nH")
output.pieces = model.get_dim("nP")
output.feats = model.get_dim("nF")
output.embed_width = model.get_dim("nI")
output.tokens = tokens
return output
cdef SizesC _get_c_sizes(model, int batch_size, int tokens) except *:
cdef SizesC sizes
sizes.states = batch_size
sizes.classes = model.get_dim("nO")
sizes.hiddens = model.get_dim("nH")
sizes.pieces = model.get_dim("nP")
sizes.feats = model.get_dim("nF")
sizes.embed_width = model.get_dim("nI")
sizes.tokens = tokens
return sizes
cdef ActivationsC alloc_activations(SizesC n) nogil:
cdef ActivationsC _alloc_activations(SizesC n) nogil:
cdef ActivationsC A
memset(&A, 0, sizeof(A))
resize_activations(&A, n)
_resize_activations(&A, n)
return A
cdef void free_activations(const ActivationsC* A) nogil:
cdef void _free_activations(const ActivationsC* A) nogil:
free(A.token_ids)
free(A.unmaxed)
free(A.hiddens)
free(A.is_valid)
cdef void resize_activations(ActivationsC* A, SizesC n) nogil:
cdef void _resize_activations(ActivationsC* A, SizesC n) nogil:
if n.states <= A._max_size:
A._curr_size = n.states
return
@@ -605,12 +605,12 @@ cdef void resize_activations(ActivationsC* A, SizesC n) nogil:
A._curr_size = n.states
cdef void predict_states(CBlas cblas, ActivationsC* A, float* scores, StateC** states, const WeightsC* W, SizesC n) nogil:
resize_activations(A, n)
cdef void _predict_states(CBlas cblas, ActivationsC* A, float* scores, StateC** states, const WeightsC* W, SizesC n) nogil:
_resize_activations(A, n)
for i in range(n.states):
states[i].set_context_tokens(&A.token_ids[i*n.feats], n.feats)
memset(A.unmaxed, 0, n.states * n.hiddens * n.pieces * sizeof(float))
sum_state_features(cblas, A.unmaxed, W.feat_weights, A.token_ids, n)
_sum_state_features(cblas, A.unmaxed, W.feat_weights, A.token_ids, n)
for i in range(n.states):
VecVec.add_i(&A.unmaxed[i*n.hiddens*n.pieces],
W.feat_bias, 1., n.hiddens * n.pieces)
@@ -641,7 +641,7 @@ cdef void predict_states(CBlas cblas, ActivationsC* A, float* scores, StateC** s
scores[i*n.classes+j] = min_
cdef void sum_state_features(CBlas cblas, float* output,
cdef void _sum_state_features(CBlas cblas, float* output,
const float* cached, const int* token_ids, SizesC n) nogil:
cdef int idx, b, f, i
cdef const float* feature

Changed file 2 of 2 (parser serialization tests)

@@ -263,12 +263,12 @@ def test_serialize_custom_nlp():
nlp2 = spacy.load(d)
model = nlp2.get_pipe("parser").model
assert model.get_ref("tok2vec") is not None
assert model.has_param("lower_W")
assert model.has_param("lower_b")
upper = model.get_ref("upper")
assert upper is not None
assert upper.has_param("W")
assert upper.has_param("b")
assert model.has_param("hidden_W")
assert model.has_param("hidden_b")
output = model.get_ref("output")
assert output is not None
assert output.has_param("W")
assert output.has_param("b")
@pytest.mark.parametrize("parser_config_string", [parser_config_string_upper])
@@ -285,12 +285,12 @@ def test_serialize_parser(parser_config_string):
nlp2 = spacy.load(d)
model = nlp2.get_pipe("parser").model
assert model.get_ref("tok2vec") is not None
assert model.has_param("lower_W")
assert model.has_param("lower_b")
upper = model.get_ref("upper")
assert upper is not None
assert upper.has_param("b")
assert upper.has_param("W")
assert model.has_param("hidden_W")
assert model.has_param("hidden_b")
output = model.get_ref("output")
assert output is not None
assert output.has_param("b")
assert output.has_param("W")
def test_config_nlp_roundtrip():