def with_bos_eos(layer):
    """Wrap `layer` so it sees one extra row at each end of every sequence.

    On the forward pass, a row filled with the string ID of "-bos-" is
    stacked before each array in the batch, and a row filled with the string
    ID of "-eos-" is stacked after it. The first and last rows of each output
    are then dropped again (this assumes the wrapped layer returns one output
    row per input row -- confirm against the layer being wrapped).

    On the backward pass, each incoming gradient is padded with a zero row at
    both ends before being handed to the wrapped layer's callback. If that
    callback returns gradients, their first and last rows are dropped;
    if it returns None, None is returned.

    An empty batch is passed straight through to `layer` in both directions.

    Relies on `get_string_id`, `get_array_module` and `wrap` being available
    in the enclosing module.

    layer: The model to wrap. Must provide `begin_update(X, drop=...)`.
    RETURNS: The result of `wrap(bos_eos_forward, layer)`.
    """
    bos = get_string_id("-bos-")
    eos = get_string_id("-eos-")

    def strip_ends(arrays):
        # Drop the first and the last row of every array in the batch.
        return [arr[1:-1] for arr in arrays]

    def bos_eos_forward(feats, drop=0.):
        if not feats:
            # Nothing to pad: defer entirely to the wrapped layer.
            return layer.begin_update(feats, drop=drop)

        first = feats[0]
        xp = get_array_module(first)
        n_cols = first.shape[1]
        # One sentinel row per marker, matching the input's width and dtype.
        bos_row = xp.array([bos] * n_cols, dtype=first.dtype)
        eos_row = xp.array([eos] * n_cols, dtype=first.dtype)

        padded_feats = [xp.vstack((bos_row, seq, eos_row)) for seq in feats]
        padded_outputs, backprop_output = layer.begin_update(
            padded_feats, drop=drop
        )
        outputs = strip_ends(padded_outputs)

        def bos_eos_backward(d_outputs, sgd=None):
            if d_outputs:
                grad_xp = get_array_module(d_outputs[0])
                # The sentinel rows were removed from the output, so they
                # receive a zero gradient.
                zero_row = grad_xp.zeros(
                    (d_outputs[0].shape[1],), dtype=d_outputs[0].dtype
                )
                padded_grads = [
                    grad_xp.vstack((zero_row, grad, zero_row))
                    for grad in d_outputs
                ]
                d_feats = backprop_output(padded_grads, sgd=sgd)
                if d_feats is not None:
                    return strip_ends(d_feats)
                return None
            # Empty batch of gradients: defer to the wrapped callback.
            return backprop_output(d_outputs, sgd=sgd)

        return outputs, bos_eos_backward

    return wrap(bos_eos_forward, layer)