mirror of https://github.com/explosion/spaCy.git
synced 2025-03-12 15:25:47 +03:00

commit c538eaf1c8 ("Work through tests")
parent d765a4f8ee

@@ -1,6 +1,7 @@
 from typing import List, Tuple, Any, Optional
 from thinc.api import Ops, Model, normal_init, chain, list2array, Linear
 from thinc.types import Floats1d, Floats2d, Floats3d, Ints2d, Floats4d
+import numpy
 from ..tokens.doc import Doc


@@ -29,7 +30,7 @@ def TransitionModel(
         forward=forward,
         init=init,
         layers=[tok2vec_projected],
-        refs={"tok2vec": tok2vec},
+        refs={"tok2vec": tok2vec_projected},
         params={
             "lower_W": None,  # Floats2d W for the hidden layer
             "lower_b": None,  # Floats1d bias for the hidden layer

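For context on the refs change: a thinc Model's named references are expected to point at nodes inside its own tree, and only tok2vec_projected appears in layers here, so the "tok2vec" ref is repointed at the projected chain that init later retrieves. A minimal sketch of the ref mechanism using the public thinc API (layer sizes are arbitrary):

import from thinc.api cannot be abbreviated, so in full:

from thinc.api import Linear, chain

inner = Linear(nO=4, nI=2)
outer = chain(inner, Linear(nO=3, nI=4))
outer.set_ref("tok2vec", inner)  # any name works; the node must be in outer's tree
assert outer.get_ref("tok2vec") is inner
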
@@ -77,8 +78,10 @@ def init(
     Y: Optional[Tuple[List[State], List[Floats2d]]] = None,
 ):
     if X is not None:
-        docs, states = X
+        docs, moves = X
         model.get_ref("tok2vec").initialize(X=docs)
+    else:
+        model.get_ref("tok2vec").initialize()
     inferred_nO = _infer_nO(Y)
     if inferred_nO is not None:
         current_nO = model.maybe_get_dim("nO")

@@ -110,7 +113,8 @@ def init(
     _lsuv_init(model)


-def forward(model, docs_moves, is_train):
+def forward(model, docs_moves: Tuple[List[Doc], TransitionSystem], is_train: bool):
+    nF = model.get_dim("nF")
     tok2vec = model.get_ref("tok2vec")
     lower_pad = model.get_param("lower_pad")
     lower_b = model.get_param("lower_b")

@@ -126,13 +130,16 @@ def forward(model, docs_moves, is_train):
     all_which = []
     all_statevecs = []
     all_scores = []
-    next_states = list(states)
+    next_states = [s for s in states if not s.is_final()]
     unseen_mask = _get_unseen_mask(model)
+    ids = numpy.zeros((len(states), nF), dtype="i")
     while next_states:
-        ids = moves.get_state_ids(states)
+        ids = ids[: len(next_states)]
+        for i, state in enumerate(next_states):
+            state.set_context_tokens(ids, i, nF)
         # Sum the state features, add the bias and apply the activation (maxout)
         # to create the state vectors.
-        preacts = _sum_state_features(feats, lower_pad, ids)
+        preacts = _sum_state_features(ops, feats, ids)
         preacts += lower_b
         statevecs, which = ops.maxout(preacts)
         # Multiply the state-vector by the scores weights and add the bias,

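A rough NumPy sketch of what one pass through this loop computes once the ids rows are filled in: gather each active state's precomputed feature rows, sum them, add the bias, and take a maxout over the pieces. The helper name and shapes are illustrative only; the real code goes through thinc's Ops.

import numpy as np

def step_state_vectors(feats, ids, lower_b):
    # feats: (n_tokens + 1, nF, nH, nP); the extra row is the padding vector
    # ids: (nB, nF) context-token indices per state; -1 selects the padding row
    # lower_b: (nH, nP) hidden-layer bias
    nB, nF = ids.shape
    preacts = feats[ids, np.arange(nF)].sum(axis=1)  # (nB, nH, nP)
    preacts += lower_b
    which = preacts.argmax(axis=-1)   # remembered for the maxout backprop
    statevecs = preacts.max(axis=-1)  # (nB, nH)
    return statevecs, which
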
@@ -141,7 +148,7 @@ def forward(model, docs_moves, is_train):
         scores += upper_b
         scores[:, unseen_mask == 0] = model.ops.xp.nanmin(scores)
         # Transition the states, filtering out any that are finished.
-        next_states = moves.transition_states(states, scores)
+        next_states = moves.transition_states(next_states, scores)
         all_scores.append(scores)
         if is_train:
             # Remember intermediate results for the backprop.

@@ -204,24 +211,23 @@ def _sum_state_features(ops: Ops, feats: Floats3d, ids: Ints2d, _arange=[]) -> F
 def _forward_precomputable_affine(model, X: Floats2d, is_train: bool):

     W: Floats4d = model.get_param("lower_W")
-    b: Floats2d = model.get_param("lower_b")
     pad: Floats4d = model.get_param("lower_pad")
     nF = model.get_dim("nF")
-    nO = model.get_dim("nO")
+    nH = model.get_dim("nH")
     nP = model.get_dim("nP")
     nI = model.get_dim("nI")
-    Yf_ = model.ops.gemm(X, model.ops.reshape2f(W, nF * nO * nP, nI), trans2=True)
-    Yf = model.ops.reshape4f(Yf_, Yf_.shape[0], nF, nO, nP)
+    Yf_ = model.ops.gemm(X, model.ops.reshape2f(W, nF * nH * nP, nI), trans2=True)
+    Yf = model.ops.reshape4f(Yf_, Yf_.shape[0], nF, nH, nP)
     Yf = model.ops.xp.vstack((Yf, pad))

     def backward(dY_ids: Tuple[Floats3d, Ints2d]):
         # This backprop is particularly tricky, because we get back a different
         # thing from what we put out. We put out an array of shape:
-        # (nB, nF, nO, nP), and get back:
-        # (nB, nO, nP) and ids (nB, nF)
+        # (nB, nF, nH, nP), and get back:
+        # (nB, nH, nP) and ids (nB, nF)
         # The ids tell us the values of nF, so we would have:
         #
-        # dYf = zeros((nB, nF, nO, nP))
+        # dYf = zeros((nB, nF, nH, nP))
         # for b in range(nB):
         #     for f in range(nF):
         #         dYf[b, ids[b, f]] += dY[b]

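The commented loop vectorizes directly: because the forward pass summed the nF feature slots, backprop through that sum broadcasts the upstream gradient into every slot unchanged. A toy NumPy check (sizes are arbitrary, not spaCy defaults):

import numpy as np

nB, nF, nH, nP = 2, 6, 8, 3
dY = np.ones((nB, nH, nP), dtype="f")
dYf = np.zeros((nB, nF, nH, nP), dtype="f")
dYf += dY[:, None]  # d(sum)/d(each summand) == 1
assert dYf.sum() == nB * nF * nH * nP
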
@@ -230,7 +236,7 @@ def _forward_precomputable_affine(model, X: Floats2d, is_train: bool):
         # in the indices.
         dY, ids = dY_ids
         assert dY.ndim == 3
-        assert dY.shape[1] == nO, dY.shape
+        assert dY.shape[1] == nH, dY.shape
         assert dY.shape[2] == nP, dY.shape
         # nB = dY.shape[0]
         model.inc_grad(

@@ -239,14 +245,14 @@ def _forward_precomputable_affine(model, X: Floats2d, is_train: bool):
         Xf = model.ops.reshape2f(X[ids], ids.shape[0], nF * nI)

         model.inc_grad("lower_b", dY.sum(axis=0))  # type: ignore
-        dY = model.ops.reshape2f(dY, dY.shape[0], nO * nP)
+        dY = model.ops.reshape2f(dY, dY.shape[0], nH * nP)

         Wopfi = W.transpose((1, 2, 0, 3))
-        Wopfi = Wopfi.reshape((nO * nP, nF * nI))
-        dXf = model.ops.gemm(dY.reshape((dY.shape[0], nO * nP)), Wopfi)
+        Wopfi = Wopfi.reshape((nH * nP, nF * nI))
+        dXf = model.ops.gemm(dY.reshape((dY.shape[0], nH * nP)), Wopfi)

         dWopfi = model.ops.gemm(dY, Xf, trans1=True)
-        dWopfi = dWopfi.reshape((nO, nP, nF, nI))
+        dWopfi = dWopfi.reshape((nH, nP, nF, nI))
         # (o, p, f, i) --> (f, o, p, i)
         dWopfi = dWopfi.transpose((2, 0, 1, 3))
         model.inc_grad("W", dWopfi)

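A quick NumPy sanity check of the axis bookkeeping above: the weight gradient comes out of the gemm as (nH, nP, nF, nI) and the transpose restores W's (nF, nH, nP, nI) layout, matching the reshape2f(W, nF * nH * nP, nI) call in the forward pass.

import numpy as np

nH, nP, nF, nI = 2, 3, 4, 5
dWopfi = np.zeros((nH, nP, nF, nI))
dW = dWopfi.transpose((2, 0, 1, 3))  # (o, p, f, i) --> (f, o, p, i)
assert dW.shape == (nF, nH, nP, nI)
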
@@ -180,3 +180,6 @@ cdef class StateClass:

     def clone(self, StateClass src):
         self.c.clone(src.c)
+
+    def set_context_tokens(self, int[:, :] output, int row, int n_feats):
+        self.c.set_context_tokens(&output[row, 0], n_feats)

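A pure-Python stand-in (toy class, not the spaCy API) for what the new method does: each state writes its n_feats context-token indices into its own row of a shared ids array, so the forward loop can batch the gathers.

import numpy as np

class ToyState:
    def __init__(self, tokens):
        self.tokens = tokens

    def set_context_tokens(self, output, row, n_feats):
        # -1 marks a missing context token (assumed padding convention)
        for f in range(n_feats):
            output[row, f] = self.tokens[f] if f < len(self.tokens) else -1

ids = np.zeros((2, 3), dtype="i")
ToyState([5, 7]).set_context_tokens(ids, 0, 3)
assert list(ids[0]) == [5, 7, -1]
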
@@ -1,6 +1,8 @@
 # cython: infer_types=True
 from __future__ import print_function
 from cymem.cymem cimport Pool
+from libc.stdlib cimport calloc, free
+from libcpp.vector cimport vector

 from collections import Counter
 import srsly

@@ -141,6 +143,16 @@ cdef class TransitionSystem:
             action.do(state.c, action.label)
             state.c.history.push_back(action.clas)

+    def transition_states(self, states, float[:, ::1] scores):
+        assert len(states) == scores.shape[0]
+        cdef StateClass state
+        cdef float* c_scores = &scores[0, 0]
+        cdef vector[StateC*] c_states
+        for state in states:
+            c_states.push_back(state.c)
+        c_transition_batch(self, &c_states[0], c_scores, scores.shape[1], scores.shape[0])
+        return [state for state in states if not state.c.is_final()]
+
     cdef Transition lookup_transition(self, object name) except *:
         raise NotImplementedError

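Roughly, the new method behaves like this pure-Python sketch, where best_valid_action is a hypothetical helper standing in for the set_valid/arg_max_if_valid pair defined below; the Cython version batches the whole loop through c_transition_batch without the GIL.

def transition_states(moves, states, scores):
    for state, state_scores in zip(states, scores):
        action = moves.best_valid_action(state, state_scores)  # hypothetical helper
        action.apply(state)  # hypothetical; cf. action.do in the real code
    return [s for s in states if not s.is_final()]
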
@@ -250,3 +262,30 @@ cdef class TransitionSystem:
         msg = util.from_bytes(bytes_data, deserializers, exclude)
         self.initialize_actions(labels)
         return self
+
+
+cdef void c_transition_batch(TransitionSystem moves, StateC** states, const float* scores,
+        int nr_class, int batch_size) nogil:
+    is_valid = <int*>calloc(moves.n_moves, sizeof(int))
+    cdef int i, guess
+    cdef Transition action
+    for i in range(batch_size):
+        moves.set_valid(is_valid, states[i])
+        guess = arg_max_if_valid(&scores[i*nr_class], is_valid, nr_class)
+        if guess == -1:
+            # This shouldn't happen, but it's hard to raise an error here,
+            # and we don't want to infinite loop. So, force to end state.
+            states[i].force_final()
+        else:
+            action = moves.c[guess]
+            action.do(states[i], action.label)
+    free(is_valid)
+
+
+cdef int arg_max_if_valid(const weight_t* scores, const int* is_valid, int n) nogil:
+    cdef int best = -1
+    for i in range(n):
+        if is_valid[i] >= 1:
+            if best == -1 or scores[i] > scores[best]:
+                best = i
+    return best

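The same argmax-over-valid-moves logic in plain Python, for reference:

def arg_max_if_valid(scores, is_valid):
    best = -1
    for i in range(len(scores)):
        if is_valid[i] and (best == -1 or scores[i] > scores[best]):
            best = i
    return best

assert arg_max_if_valid([0.2, 0.9, 0.4], [1, 0, 1]) == 2
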
@@ -92,8 +92,9 @@ class Parser(TrainablePipe):
     @property
     def move_names(self):
         names = []
+        cdef TransitionSystem moves = self.moves
         for i in range(self.moves.n_moves):
-            name = self.moves.move_name(self.moves.c[i].move, self.moves.c[i].label)
+            name = self.moves.move_name(moves.c[i].move, moves.c[i].label)
             # Explicitly removing the internal "U-" token used for blocking entities
             if name != "U-":
                 names.append(name)

@@ -273,14 +274,14 @@ class Parser(TrainablePipe):
         cdef TransitionSystem moves = self.moves
         cdef StateClass state
         cdef int clas
-        cdef int nF = self.model.state2vec.nF
+        cdef int nF = self.model.get_dim("nF")
         cdef int nO = moves.n_moves
         cdef int nS = sum([len(history) for history in histories])
         cdef np.ndarray costs = numpy.zeros((nS, nO), dtype="f")
         cdef Pool mem = Pool()
         is_valid = <int*>mem.alloc(nO, sizeof(int))
         c_costs = <float*>costs.data
-        states = moves.init_states([eg.x for eg in examples])
+        states = moves.init_batch([eg.x for eg in examples])
         cdef int i = 0
         for eg, state, history in zip(examples, states, histories):
             gold = moves.init_gold(state, eg)

@@ -342,7 +343,7 @@ class Parser(TrainablePipe):
         for example in islice(get_examples(), 10):
             doc_sample.append(example.predicted)
         assert len(doc_sample) > 0, Errors.E923.format(name=self.name)
-        self.model.initialize(doc_sample)
+        self.model.initialize((doc_sample, self.moves))
         if nlp is not None:
             self.init_multitask_objectives(get_examples, nlp.pipeline)

@@ -5,8 +5,6 @@ from pathlib import Path
 from spacy.about import __version__ as spacy_version
 from spacy import util
 from spacy import prefer_gpu, require_gpu, require_cpu
-from spacy.ml._precomputable_affine import PrecomputableAffine
-from spacy.ml._precomputable_affine import _backprop_precomputable_affine_padding
 from spacy.util import dot_to_object, SimpleFrozenList
 from thinc.api import Config, Optimizer, ConfigValidationError
 from spacy.training.batchers import minibatch_by_words