diff --git a/spacy/_ml.py b/spacy/_ml.py
index 6104324ab..86dac6c7a 100644
--- a/spacy/_ml.py
+++ b/spacy/_ml.py
@@ -953,16 +953,24 @@ class CharacterEmbed(Model):
         return output, backprop_character_embed
 
 
-def get_cossim_loss(yh, y):
+def get_cossim_loss(yh, y, ignore_zeros=False):
+    xp = get_array_module(yh)
+    # Find the zero vectors
+    if ignore_zeros:
+        zero_indices = xp.abs(y).sum(axis=1) == 0
     # Add a small constant to avoid 0 vectors
     yh = yh + 1e-8
     y = y + 1e-8
     # https://math.stackexchange.com/questions/1923613/partial-derivative-of-cosine-similarity
-    xp = get_array_module(yh)
     norm_yh = xp.linalg.norm(yh, axis=1, keepdims=True)
     norm_y = xp.linalg.norm(y, axis=1, keepdims=True)
     mul_norms = norm_yh * norm_y
     cosine = (yh * y).sum(axis=1, keepdims=True) / mul_norms
     d_yh = (y / mul_norms) - (cosine * (yh / norm_yh ** 2))
-    loss = xp.abs(cosine - 1).sum()
+    losses = xp.abs(cosine - 1)
+    if ignore_zeros:
+        # If the target was a zero vector, don't count it in the loss.
+        d_yh[zero_indices] = 0
+        losses[zero_indices] = 0
+    loss = losses.sum()
     return loss, -d_yh
diff --git a/spacy/cli/pretrain.py b/spacy/cli/pretrain.py
index 60f703d2f..891e15fa2 100644
--- a/spacy/cli/pretrain.py
+++ b/spacy/cli/pretrain.py
@@ -35,6 +35,7 @@ from .train import _load_pretrained_tok2vec
     output_dir=("Directory to write models to on each epoch", "positional", None, str),
     width=("Width of CNN layers", "option", "cw", int),
     depth=("Depth of CNN layers", "option", "cd", int),
+    bilstm_depth=("Depth of BiLSTM layers (requires PyTorch)", "option", "lstm", int),
     embed_rows=("Number of embedding rows", "option", "er", int),
     loss_func=(
         "Loss function to use for the objective. Either 'L2' or 'cosine'",
@@ -80,6 +81,7 @@ def pretrain(
     output_dir,
     width=96,
     depth=4,
+    bilstm_depth=2,
     embed_rows=2000,
     loss_func="cosine",
     use_vectors=False,
@@ -116,6 +118,10 @@ def pretrain(
     util.fix_random_seed(seed)
 
     has_gpu = prefer_gpu()
+    if has_gpu:
+        import torch
+
+        torch.set_default_tensor_type("torch.cuda.FloatTensor")
     msg.info("Using GPU" if has_gpu else "Not using GPU")
 
     output_dir = Path(output_dir)
@@ -151,7 +157,7 @@ def pretrain(
             embed_rows,
             conv_depth=depth,
             pretrained_vectors=pretrained_vectors,
-            bilstm_depth=0,  # Requires PyTorch. Experimental.
+            bilstm_depth=bilstm_depth,  # Requires PyTorch. Experimental.
             cnn_maxout_pieces=3,  # You can try setting this higher
             subword_features=True,  # Set to False for Chinese etc
         ),
diff --git a/spacy/pipeline/pipes.pyx b/spacy/pipeline/pipes.pyx
index 23509fcae..63ab09e56 100644
--- a/spacy/pipeline/pipes.pyx
+++ b/spacy/pipeline/pipes.pyx
@@ -29,7 +29,7 @@
 from .._ml import Tok2Vec, build_tagger_model, cosine, get_cossim_loss
 from .._ml import build_text_classifier, build_simple_cnn_text_classifier
 from .._ml import build_bow_text_classifier, build_nel_encoder
 from .._ml import link_vectors_to_models, zero_init, flatten
-from .._ml import masked_language_model, create_default_optimizer
+from .._ml import masked_language_model, create_default_optimizer, get_cossim_loss
 from ..errors import Errors, TempErrors, user_warning, Warnings
 from .. import util
@@ -880,8 +880,7 @@ class ClozeMultitask(Pipe):
         # and look them up all at once. This prevents data copying.
         ids = self.model.ops.flatten([doc.to_array(ID).ravel() for doc in docs])
         target = vectors[ids]
-        gradient = (prediction - target) / prediction.shape[0]
-        loss = (gradient**2).sum()
+        loss, gradient = get_cossim_loss(prediction, target, ignore_zeros=True)
         return float(loss), gradient
 
     def update(self, docs, golds, drop=0., sgd=None, losses=None):
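To make the effect of the new ignore_zeros flag concrete, here is a minimal standalone sketch of the patched get_cossim_loss, with plain NumPy standing in for Thinc's get_array_module so it runs without spaCy (the example inputs are made up for illustration). A row whose target vector is all zeros, e.g. a token with no entry in the vectors table, contributes neither loss nor gradient.

import numpy as np

def get_cossim_loss(yh, y, ignore_zeros=False):
    xp = np  # stand-in for get_array_module(yh)
    # Find the rows whose target is a zero vector
    if ignore_zeros:
        zero_indices = xp.abs(y).sum(axis=1) == 0
    # Add a small constant to avoid zero-norm vectors
    yh = yh + 1e-8
    y = y + 1e-8
    norm_yh = xp.linalg.norm(yh, axis=1, keepdims=True)
    norm_y = xp.linalg.norm(y, axis=1, keepdims=True)
    mul_norms = norm_yh * norm_y
    cosine = (yh * y).sum(axis=1, keepdims=True) / mul_norms
    d_yh = (y / mul_norms) - (cosine * (yh / norm_yh ** 2))
    losses = xp.abs(cosine - 1)
    if ignore_zeros:
        # Zero-vector targets count for neither loss nor gradient
        d_yh[zero_indices] = 0
        losses[zero_indices] = 0
    return losses.sum(), -d_yh

prediction = np.array([[0.6, 0.8], [1.0, 0.0]])
target = np.array([[1.0, 0.0], [0.0, 0.0]])  # second row: zero-vector target
loss, d_prediction = get_cossim_loss(prediction, target, ignore_zeros=True)
print(loss)             # loss from the first row only
print(d_prediction[1])  # all zeros: no update flows to the masked row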
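Since the BiLSTM depth is now exposed on the command line, a pretraining run could enable it roughly like this (a sketch assuming spaCy v2.1's pretrain CLI and plac's flag naming; the paths are placeholders):

python -m spacy pretrain texts.jsonl /path/to/vectors_model /path/to/output -lstm 2

Note that the patch also changes the default bilstm_depth to 2, so BiLSTM layers are active unless the depth is set back to 0, and per the comment in the diff, any non-zero depth requires PyTorch to be installed.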