Add cnn_window option to pretrain

This commit is contained in:
Matthw Honnibal 2019-10-20 17:46:34 +02:00
parent 3a67aa857e
commit 5a601ef46a

View File

@ -19,7 +19,7 @@ from ..errors import Errors
from ..tokens import Doc
from ..attrs import ID, HEAD
from .._ml import Tok2Vec, flatten, chain, create_default_optimizer
from .._ml import masked_language_model, get_cossim_loss
from .._ml import masked_language_model, get_cossim_loss, get_characters_loss
from .._ml import MultiSoftmax
from .. import util
from .train import _load_pretrained_tok2vec
@ -37,6 +37,7 @@ from .train import _load_pretrained_tok2vec
output_dir=("Directory to write models to on each epoch", "positional", None, str),
width=("Width of CNN layers", "option", "cw", int),
depth=("Depth of CNN layers", "option", "cd", int),
cnn_window=("Window size for CNN layers", "option", "cW", int),
use_chars=("Whether to use character-based embedding", "flag", "chr", bool),
sa_depth=("Depth of self-attention layers", "option", "sa", int),
bilstm_depth=("Depth of BiLSTM layers (requires PyTorch)", "option", "lstm", int),
@ -88,6 +89,7 @@ def pretrain(
bilstm_depth=0,
sa_depth=0,
use_chars=False,
cnn_window=1,
embed_rows=2000,
loss_func="cosine",
use_vectors=False,
@ -158,6 +160,7 @@ def pretrain(
width,
embed_rows,
conv_depth=depth,
conv_window=cnn_window,
pretrained_vectors=pretrained_vectors,
char_embed=use_chars,
self_attn_depth=sa_depth, # Experimental.
@ -297,16 +300,6 @@ def make_docs(nlp, batch, min_length, max_length):
return docs, skip_count
def get_characters_loss(ops, docs, prediction, nr_char=10):
target_ids = numpy.vstack([doc.to_utf8_array(nr_char=nr_char) for doc in docs])
target_ids = target_ids.reshape((-1,))
target = ops.asarray(to_categorical(target_ids, nb_classes=256), dtype="f")
target = target.reshape((-1, 256*nr_char))
diff = prediction - target
loss = (diff**2).sum()
d_target = diff / float(prediction.shape[0])
return loss, d_target
def get_vectors_loss(ops, docs, prediction, objective="L2"):
"""Compute a mean-squared error loss between the documents' vectors and