Work through tests

This commit is contained in:
Matthew Honnibal 2021-10-26 01:21:51 +02:00
parent d765a4f8ee
commit c538eaf1c8
5 changed files with 72 additions and 25 deletions

View File

@ -1,6 +1,7 @@
from typing import List, Tuple, Any, Optional from typing import List, Tuple, Any, Optional
from thinc.api import Ops, Model, normal_init, chain, list2array, Linear from thinc.api import Ops, Model, normal_init, chain, list2array, Linear
from thinc.types import Floats1d, Floats2d, Floats3d, Ints2d, Floats4d from thinc.types import Floats1d, Floats2d, Floats3d, Ints2d, Floats4d
import numpy
from ..tokens.doc import Doc from ..tokens.doc import Doc
@ -29,7 +30,7 @@ def TransitionModel(
forward=forward, forward=forward,
init=init, init=init,
layers=[tok2vec_projected], layers=[tok2vec_projected],
refs={"tok2vec": tok2vec}, refs={"tok2vec": tok2vec_projected},
params={ params={
"lower_W": None, # Floats2d W for the hidden layer "lower_W": None, # Floats2d W for the hidden layer
"lower_b": None, # Floats1d bias for the hidden layer "lower_b": None, # Floats1d bias for the hidden layer
@ -77,8 +78,10 @@ def init(
Y: Optional[Tuple[List[State], List[Floats2d]]] = None, Y: Optional[Tuple[List[State], List[Floats2d]]] = None,
): ):
if X is not None: if X is not None:
docs, states = X docs, moves = X
model.get_ref("tok2vec").initialize(X=docs) model.get_ref("tok2vec").initialize(X=docs)
else:
model.get_ref("tok2vec").initialize()
inferred_nO = _infer_nO(Y) inferred_nO = _infer_nO(Y)
if inferred_nO is not None: if inferred_nO is not None:
current_nO = model.maybe_get_dim("nO") current_nO = model.maybe_get_dim("nO")
@ -110,7 +113,8 @@ def init(
_lsuv_init(model) _lsuv_init(model)
def forward(model, docs_moves, is_train): def forward(model, docs_moves: Tuple[List[Doc], TransitionSystem], is_train: bool):
nF = model.get_dim("nF")
tok2vec = model.get_ref("tok2vec") tok2vec = model.get_ref("tok2vec")
lower_pad = model.get_param("lower_pad") lower_pad = model.get_param("lower_pad")
lower_b = model.get_param("lower_b") lower_b = model.get_param("lower_b")
@ -126,13 +130,16 @@ def forward(model, docs_moves, is_train):
all_which = [] all_which = []
all_statevecs = [] all_statevecs = []
all_scores = [] all_scores = []
next_states = list(states) next_states = [s for s in states if not s.is_final()]
unseen_mask = _get_unseen_mask(model) unseen_mask = _get_unseen_mask(model)
ids = numpy.zeros((len(states), nF), dtype="i")
while next_states: while next_states:
ids = moves.get_state_ids(states) ids = ids[: len(next_states)]
for i, state in enumerate(next_states):
state.set_context_tokens(ids, i, nF)
# Sum the state features, add the bias and apply the activation (maxout) # Sum the state features, add the bias and apply the activation (maxout)
# to create the state vectors. # to create the state vectors.
preacts = _sum_state_features(feats, lower_pad, ids) preacts = _sum_state_features(ops, feats, ids)
preacts += lower_b preacts += lower_b
statevecs, which = ops.maxout(preacts) statevecs, which = ops.maxout(preacts)
# Multiply the state-vector by the scores weights and add the bias, # Multiply the state-vector by the scores weights and add the bias,
@ -141,7 +148,7 @@ def forward(model, docs_moves, is_train):
scores += upper_b scores += upper_b
scores[:, unseen_mask == 0] = model.ops.xp.nanmin(scores) scores[:, unseen_mask == 0] = model.ops.xp.nanmin(scores)
# Transition the states, filtering out any that are finished. # Transition the states, filtering out any that are finished.
next_states = moves.transition_states(states, scores) next_states = moves.transition_states(next_states, scores)
all_scores.append(scores) all_scores.append(scores)
if is_train: if is_train:
# Remember intermediate results for the backprop. # Remember intermediate results for the backprop.
@ -204,24 +211,23 @@ def _sum_state_features(ops: Ops, feats: Floats3d, ids: Ints2d, _arange=[]) -> F
def _forward_precomputable_affine(model, X: Floats2d, is_train: bool): def _forward_precomputable_affine(model, X: Floats2d, is_train: bool):
W: Floats4d = model.get_param("lower_W") W: Floats4d = model.get_param("lower_W")
b: Floats2d = model.get_param("lower_b")
pad: Floats4d = model.get_param("lower_pad") pad: Floats4d = model.get_param("lower_pad")
nF = model.get_dim("nF") nF = model.get_dim("nF")
nO = model.get_dim("nO") nH = model.get_dim("nH")
nP = model.get_dim("nP") nP = model.get_dim("nP")
nI = model.get_dim("nI") nI = model.get_dim("nI")
Yf_ = model.ops.gemm(X, model.ops.reshape2f(W, nF * nO * nP, nI), trans2=True) Yf_ = model.ops.gemm(X, model.ops.reshape2f(W, nF * nH * nP, nI), trans2=True)
Yf = model.ops.reshape4f(Yf_, Yf_.shape[0], nF, nO, nP) Yf = model.ops.reshape4f(Yf_, Yf_.shape[0], nF, nH, nP)
Yf = model.ops.xp.vstack((Yf, pad)) Yf = model.ops.xp.vstack((Yf, pad))
def backward(dY_ids: Tuple[Floats3d, Ints2d]): def backward(dY_ids: Tuple[Floats3d, Ints2d]):
# This backprop is particularly tricky, because we get back a different # This backprop is particularly tricky, because we get back a different
# thing from what we put out. We put out an array of shape: # thing from what we put out. We put out an array of shape:
# (nB, nF, nO, nP), and get back: # (nB, nF, nH, nP), and get back:
# (nB, nO, nP) and ids (nB, nF) # (nB, nH, nP) and ids (nB, nF)
# The ids tell us the values of nF, so we would have: # The ids tell us the values of nF, so we would have:
# #
# dYf = zeros((nB, nF, nO, nP)) # dYf = zeros((nB, nF, nH, nP))
# for b in range(nB): # for b in range(nB):
# for f in range(nF): # for f in range(nF):
# dYf[b, ids[b, f]] += dY[b] # dYf[b, ids[b, f]] += dY[b]
@ -230,7 +236,7 @@ def _forward_precomputable_affine(model, X: Floats2d, is_train: bool):
# in the indices. # in the indices.
dY, ids = dY_ids dY, ids = dY_ids
assert dY.ndim == 3 assert dY.ndim == 3
assert dY.shape[1] == nO, dY.shape assert dY.shape[1] == nH, dY.shape
assert dY.shape[2] == nP, dY.shape assert dY.shape[2] == nP, dY.shape
# nB = dY.shape[0] # nB = dY.shape[0]
model.inc_grad( model.inc_grad(
@ -239,14 +245,14 @@ def _forward_precomputable_affine(model, X: Floats2d, is_train: bool):
Xf = model.ops.reshape2f(X[ids], ids.shape[0], nF * nI) Xf = model.ops.reshape2f(X[ids], ids.shape[0], nF * nI)
model.inc_grad("lower_b", dY.sum(axis=0)) # type: ignore model.inc_grad("lower_b", dY.sum(axis=0)) # type: ignore
dY = model.ops.reshape2f(dY, dY.shape[0], nO * nP) dY = model.ops.reshape2f(dY, dY.shape[0], nH * nP)
Wopfi = W.transpose((1, 2, 0, 3)) Wopfi = W.transpose((1, 2, 0, 3))
Wopfi = Wopfi.reshape((nO * nP, nF * nI)) Wopfi = Wopfi.reshape((nH * nP, nF * nI))
dXf = model.ops.gemm(dY.reshape((dY.shape[0], nO * nP)), Wopfi) dXf = model.ops.gemm(dY.reshape((dY.shape[0], nH * nP)), Wopfi)
dWopfi = model.ops.gemm(dY, Xf, trans1=True) dWopfi = model.ops.gemm(dY, Xf, trans1=True)
dWopfi = dWopfi.reshape((nO, nP, nF, nI)) dWopfi = dWopfi.reshape((nH, nP, nF, nI))
# (o, p, f, i) --> (f, o, p, i) # (o, p, f, i) --> (f, o, p, i)
dWopfi = dWopfi.transpose((2, 0, 1, 3)) dWopfi = dWopfi.transpose((2, 0, 1, 3))
model.inc_grad("W", dWopfi) model.inc_grad("W", dWopfi)

View File

@ -180,3 +180,6 @@ cdef class StateClass:
def clone(self, StateClass src): def clone(self, StateClass src):
self.c.clone(src.c) self.c.clone(src.c)
def set_context_tokens(self, int[:, :] output, int row, int n_feats):
self.c.set_context_tokens(&output[row, 0], n_feats)

View File

@ -1,6 +1,8 @@
# cython: infer_types=True # cython: infer_types=True
from __future__ import print_function from __future__ import print_function
from cymem.cymem cimport Pool from cymem.cymem cimport Pool
from libc.stdlib cimport calloc, free
from libcpp.vector cimport vector
from collections import Counter from collections import Counter
import srsly import srsly
@ -141,6 +143,16 @@ cdef class TransitionSystem:
action.do(state.c, action.label) action.do(state.c, action.label)
state.c.history.push_back(action.clas) state.c.history.push_back(action.clas)
def transition_states(self, states, float[:, ::1] scores):
assert len(states) == scores.shape[0]
cdef StateClass state
cdef float* c_scores = &scores[0, 0]
cdef vector[StateC*] c_states
for state in states:
c_states.push_back(state.c)
c_transition_batch(self, &c_states[0], c_scores, scores.shape[1], scores.shape[0])
return [state for state in states if not state.c.is_final()]
cdef Transition lookup_transition(self, object name) except *: cdef Transition lookup_transition(self, object name) except *:
raise NotImplementedError raise NotImplementedError
@ -250,3 +262,30 @@ cdef class TransitionSystem:
msg = util.from_bytes(bytes_data, deserializers, exclude) msg = util.from_bytes(bytes_data, deserializers, exclude)
self.initialize_actions(labels) self.initialize_actions(labels)
return self return self
cdef void c_transition_batch(TransitionSystem moves, StateC** states, const float* scores,
int nr_class, int batch_size) nogil:
is_valid = <int*>calloc(moves.n_moves, sizeof(int))
cdef int i, guess
cdef Transition action
for i in range(batch_size):
moves.set_valid(is_valid, states[i])
guess = arg_max_if_valid(&scores[i*nr_class], is_valid, nr_class)
if guess == -1:
# This shouldn't happen, but it's hard to raise an error here,
# and we don't want to infinite loop. So, force to end state.
states[i].force_final()
else:
action = moves.c[guess]
action.do(states[i], action.label)
free(is_valid)
cdef int arg_max_if_valid(const weight_t* scores, const int* is_valid, int n) nogil:
cdef int best = -1
for i in range(n):
if is_valid[i] >= 1:
if best == -1 or scores[i] > scores[best]:
best = i
return best

View File

@ -92,8 +92,9 @@ class Parser(TrainablePipe):
@property @property
def move_names(self): def move_names(self):
names = [] names = []
cdef TransitionSystem moves = self.moves
for i in range(self.moves.n_moves): for i in range(self.moves.n_moves):
name = self.moves.move_name(self.moves.c[i].move, self.moves.c[i].label) name = self.moves.move_name(moves.c[i].move, moves.c[i].label)
# Explicitly removing the internal "U-" token used for blocking entities # Explicitly removing the internal "U-" token used for blocking entities
if name != "U-": if name != "U-":
names.append(name) names.append(name)
@ -273,14 +274,14 @@ class Parser(TrainablePipe):
cdef TransitionSystem moves = self.moves cdef TransitionSystem moves = self.moves
cdef StateClass state cdef StateClass state
cdef int clas cdef int clas
cdef int nF = self.model.state2vec.nF cdef int nF = self.model.get_dim("nF")
cdef int nO = moves.n_moves cdef int nO = moves.n_moves
cdef int nS = sum([len(history) for history in histories]) cdef int nS = sum([len(history) for history in histories])
cdef np.ndarray costs = numpy.zeros((nS, nO), dtype="f") cdef np.ndarray costs = numpy.zeros((nS, nO), dtype="f")
cdef Pool mem = Pool() cdef Pool mem = Pool()
is_valid = <int*>mem.alloc(nO, sizeof(int)) is_valid = <int*>mem.alloc(nO, sizeof(int))
c_costs = <float*>costs.data c_costs = <float*>costs.data
states = moves.init_states([eg.x for eg in examples]) states = moves.init_batch([eg.x for eg in examples])
cdef int i = 0 cdef int i = 0
for eg, state, history in zip(examples, states, histories): for eg, state, history in zip(examples, states, histories):
gold = moves.init_gold(state, eg) gold = moves.init_gold(state, eg)
@ -342,7 +343,7 @@ class Parser(TrainablePipe):
for example in islice(get_examples(), 10): for example in islice(get_examples(), 10):
doc_sample.append(example.predicted) doc_sample.append(example.predicted)
assert len(doc_sample) > 0, Errors.E923.format(name=self.name) assert len(doc_sample) > 0, Errors.E923.format(name=self.name)
self.model.initialize(doc_sample) self.model.initialize((doc_sample, self.moves))
if nlp is not None: if nlp is not None:
self.init_multitask_objectives(get_examples, nlp.pipeline) self.init_multitask_objectives(get_examples, nlp.pipeline)

View File

@ -5,8 +5,6 @@ from pathlib import Path
from spacy.about import __version__ as spacy_version from spacy.about import __version__ as spacy_version
from spacy import util from spacy import util
from spacy import prefer_gpu, require_gpu, require_cpu from spacy import prefer_gpu, require_gpu, require_cpu
from spacy.ml._precomputable_affine import PrecomputableAffine
from spacy.ml._precomputable_affine import _backprop_precomputable_affine_padding
from spacy.util import dot_to_object, SimpleFrozenList from spacy.util import dot_to_object, SimpleFrozenList
from thinc.api import Config, Optimizer, ConfigValidationError from thinc.api import Config, Optimizer, ConfigValidationError
from spacy.training.batchers import minibatch_by_words from spacy.training.batchers import minibatch_by_words