diff --git a/spacy/cli/pretrain.py b/spacy/cli/pretrain.py index 303fde105..fb2333b7d 100644 --- a/spacy/cli/pretrain.py +++ b/spacy/cli/pretrain.py @@ -19,7 +19,7 @@ from ..errors import Errors from ..tokens import Doc from ..attrs import ID, HEAD from .._ml import Tok2Vec, flatten, chain, create_default_optimizer -from .._ml import masked_language_model, get_cossim_loss +from .._ml import masked_language_model, get_cossim_loss, get_characters_loss from .._ml import MultiSoftmax from .. import util from .train import _load_pretrained_tok2vec @@ -37,6 +37,7 @@ from .train import _load_pretrained_tok2vec output_dir=("Directory to write models to on each epoch", "positional", None, str), width=("Width of CNN layers", "option", "cw", int), depth=("Depth of CNN layers", "option", "cd", int), + cnn_window=("Window size for CNN layers", "option", "cW", int), use_chars=("Whether to use character-based embedding", "flag", "chr", bool), sa_depth=("Depth of self-attention layers", "option", "sa", int), bilstm_depth=("Depth of BiLSTM layers (requires PyTorch)", "option", "lstm", int), @@ -88,6 +89,7 @@ def pretrain( bilstm_depth=0, sa_depth=0, use_chars=False, + cnn_window=1, embed_rows=2000, loss_func="cosine", use_vectors=False, @@ -158,6 +160,7 @@ def pretrain( width, embed_rows, conv_depth=depth, + conv_window=cnn_window, pretrained_vectors=pretrained_vectors, char_embed=use_chars, self_attn_depth=sa_depth, # Experimental. @@ -297,16 +300,6 @@ def make_docs(nlp, batch, min_length, max_length): return docs, skip_count -def get_characters_loss(ops, docs, prediction, nr_char=10): - target_ids = numpy.vstack([doc.to_utf8_array(nr_char=nr_char) for doc in docs]) - target_ids = target_ids.reshape((-1,)) - target = ops.asarray(to_categorical(target_ids, nb_classes=256), dtype="f") - target = target.reshape((-1, 256*nr_char)) - diff = prediction - target - loss = (diff**2).sum() - d_target = diff / float(prediction.shape[0]) - return loss, d_target - def get_vectors_loss(ops, docs, prediction, objective="L2"): """Compute a mean-squared error loss between the documents' vectors and