From ffe0451d0972ec209556dc7aad356deca1cbe0a7 Mon Sep 17 00:00:00 2001 From: svlandeg Date: Wed, 3 Jun 2020 14:45:00 +0200 Subject: [PATCH] pretrain from config --- examples/experiments/onto-joint/pretrain.cfg | 144 +++++++++++++++ spacy/_ml.py | 0 spacy/cli/pretrain.py | 179 +++++++------------ spacy/errors.py | 2 - spacy/ml/models/multi_task.py | 84 ++++++++- 5 files changed, 286 insertions(+), 123 deletions(-) create mode 100644 examples/experiments/onto-joint/pretrain.cfg delete mode 100644 spacy/_ml.py diff --git a/examples/experiments/onto-joint/pretrain.cfg b/examples/experiments/onto-joint/pretrain.cfg new file mode 100644 index 000000000..6a41cc677 --- /dev/null +++ b/examples/experiments/onto-joint/pretrain.cfg @@ -0,0 +1,144 @@ +# Training hyper-parameters and additional features. +[training] +# Whether to train on sequences with 'gold standard' sentence boundaries +# and tokens. If you set this to true, take care to ensure your run-time +# data is passed in sentence-by-sentence via some prior preprocessing. +gold_preproc = false +# Limitations on training document length or number of examples. +max_length = 0 +limit = 0 +# Data augmentation +orth_variant_level = 0.0 +dropout = 0.1 +# Controls early-stopping. 0 or -1 mean unlimited. +patience = 1600 +max_epochs = 0 +max_steps = 20000 +eval_frequency = 400 +# Other settings +seed = 0 +accumulate_gradient = 1 +use_pytorch_for_gpu_memory = false +# Control how scores are printed and checkpoints are evaluated. +scores = ["speed", "tags_acc", "uas", "las", "ents_f"] +score_weights = {"las": 0.4, "ents_f": 0.4, "tags_acc": 0.2} +# These settings are invalid for the transformer models. +init_tok2vec = null +vectors = null +discard_oversize = false + +[training.batch_size] +@schedules = "compounding.v1" +start = 1000 +stop = 1000 +compound = 1.001 + +[training.optimizer] +@optimizers = "Adam.v1" +beta1 = 0.9 +beta2 = 0.999 +L2_is_weight_decay = true +L2 = 0.01 +grad_clip = 1.0 +use_averages = true +eps = 1e-8 +learn_rate = 0.001 + +[pretraining] +max_epochs = 100 +min_length = 5 +max_length = 500 +dropout = 0.2 +n_save_every = null +batch_size = 3000 + +[pretraining.model] +@architectures = "spacy.HashEmbedCNN.v1" +pretrained_vectors = ${nlp:vectors} +width = 256 +depth = 6 +window_size = 1 +embed_size = 2000 +maxout_pieces = 3 +subword_features = true +dropout = null + +[pretraining.optimizer] +@optimizers = "Adam.v1" +beta1 = 0.9 +beta2 = 0.999 +L2_is_weight_decay = true +L2 = 0.01 +grad_clip = 1.0 +use_averages = true +eps = 1e-8 +learn_rate = 0.001 + +[pretraining.loss_func] +@losses = "CosineDistance.v1" + +[nlp] +lang = "en" +vectors = ${training:vectors} + +[nlp.pipeline.tok2vec] +factory = "tok2vec" + +[nlp.pipeline.senter] +factory = "senter" + +[nlp.pipeline.ner] +factory = "ner" + +[nlp.pipeline.tagger] +factory = "tagger" + +[nlp.pipeline.parser] +factory = "parser" + +[nlp.pipeline.senter.model] +@architectures = "spacy.Tagger.v1" + +[nlp.pipeline.senter.model.tok2vec] +@architectures = "spacy.Tok2VecTensors.v1" +width = ${nlp.pipeline.tok2vec.model:width} + +[nlp.pipeline.tagger.model] +@architectures = "spacy.Tagger.v1" + +[nlp.pipeline.tagger.model.tok2vec] +@architectures = "spacy.Tok2VecTensors.v1" +width = ${nlp.pipeline.tok2vec.model:width} + +[nlp.pipeline.parser.model] +@architectures = "spacy.TransitionBasedParser.v1" +nr_feature_tokens = 8 +hidden_width = 128 +maxout_pieces = 3 +use_upper = false + +[nlp.pipeline.parser.model.tok2vec] +@architectures = "spacy.Tok2VecTensors.v1" +width = ${nlp.pipeline.tok2vec.model:width} + +[nlp.pipeline.ner.model] +@architectures = "spacy.TransitionBasedParser.v1" +nr_feature_tokens = 3 +hidden_width = 128 +maxout_pieces = 3 +use_upper = false + +[nlp.pipeline.ner.model.tok2vec] +@architectures = "spacy.Tok2VecTensors.v1" +width = ${nlp.pipeline.tok2vec.model:width} + +[nlp.pipeline.tok2vec.model] +@architectures = "spacy.HashEmbedCNN.v1" +pretrained_vectors = ${nlp:vectors} +width = 256 +depth = 6 +window_size = 1 +embed_size = 10000 +maxout_pieces = 3 +subword_features = true +dropout = null diff --git a/spacy/_ml.py b/spacy/_ml.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/spacy/cli/pretrain.py b/spacy/cli/pretrain.py index b2e3229ee..0022a0d07 100644 --- a/spacy/cli/pretrain.py +++ b/spacy/cli/pretrain.py @@ -3,48 +3,36 @@ import numpy import time import re from collections import Counter +import plac from pathlib import Path -from thinc.api import Linear, Maxout, chain, list2array, prefer_gpu -from thinc.api import CosineDistance, L2Distance +from thinc.api import Linear, Maxout, chain, list2array from wasabi import msg import srsly +from thinc.api import use_pytorch_for_gpu_memory -from ..gold import Example from ..errors import Errors from ..ml.models.multi_task import build_masked_language_model from ..tokens import Doc from ..attrs import ID, HEAD -from ..ml.models.tok2vec import build_Tok2Vec_model from .. import util -from ..util import create_default_optimizer -from .train import _load_pretrained_tok2vec +from ..gold import Example -def pretrain( +@plac.annotations( # fmt: off - texts_loc: ("Path to JSONL file with raw texts to learn from, with text provided as the key 'text' or tokens as the key 'tokens'", "positional", None, str), - vectors_model: ("Name or path to spaCy model with vectors to learn from", "positional", None, str), - output_dir: ("Directory to write models to on each epoch", "positional", None, str), - width: ("Width of CNN layers", "option", "cw", int) = 96, - conv_depth: ("Depth of CNN layers", "option", "cd", int) = 4, - bilstm_depth: ("Depth of BiLSTM layers (requires PyTorch)", "option", "lstm", int) = 0, - cnn_pieces: ("Maxout size for CNN layers. 1 for Mish", "option", "cP", int) = 3, - sa_depth: ("Depth of self-attention layers", "option", "sa", int) = 0, - use_chars: ("Whether to use character-based embedding", "flag", "chr", bool) = False, - cnn_window: ("Window size for CNN layers", "option", "cW", int) = 1, - embed_rows: ("Number of embedding rows", "option", "er", int) = 2000, - loss_func: ("Loss function to use for the objective. Either 'L2' or 'cosine'", "option", "L", str) = "cosine", - use_vectors: ("Whether to use the static vectors as input features", "flag", "uv") = False, - dropout: ("Dropout rate", "option", "d", float) = 0.2, - n_iter: ("Number of iterations to pretrain", "option", "i", int) = 1000, - batch_size: ("Number of words per training batch", "option", "bs", int) = 3000, - max_length: ("Max words per example. Longer examples are discarded", "option", "xw", int) = 500, - min_length: ("Min words per example. Shorter examples are discarded", "option", "nw", int) = 5, - seed: ("Seed for random number generators", "option", "s", int) = 0, - n_save_every: ("Save model every X batches.", "option", "se", int) = None, - init_tok2vec: ("Path to pretrained weights for the token-to-vector parts of the models. See 'spacy pretrain'. Experimental.", "option", "t2v", Path) = None, - epoch_start: ("The epoch to start counting at. Only relevant when using '--init-tok2vec' and the given weight file has been renamed. Prevents unintended overwriting of existing weight files.", "option", "es", int) = None, + texts_loc=("Path to JSONL file with raw texts to learn from, with text provided as the key 'text' or tokens as the key 'tokens'", "positional", None, str), + vectors_model=("Name or path to spaCy model with vectors to learn from", "positional", None, str), + config_path=("Path to config file", "positional", None, Path), + output_dir=("Directory to write models to on each epoch", "positional", None, Path), + use_gpu=("Use GPU", "option", "g", int), # fmt: on +) +def pretrain( + texts_loc, + vectors_model, + config_path, + output_dir, + use_gpu=-1, ): """ Pre-train the 'token-to-vector' (tok2vec) layer of pipeline components, @@ -58,23 +46,24 @@ def pretrain( However, it's still quite experimental, so your mileage may vary. To load the weights back in during 'spacy train', you need to ensure - all settings are the same between pretraining and training. The API and - errors around this need some improvement. + all settings are the same between pretraining and training. Ideally, + this is done by using the same config file for both commands. """ - config = dict(locals()) - for key in config: - if isinstance(config[key], Path): - config[key] = str(config[key]) - util.fix_random_seed(seed) + if not config_path or not config_path.exists(): + msg.fail("Config file not found", config_path, exits=1) - has_gpu = prefer_gpu() - if has_gpu: - import torch + if use_gpu >= 0: + msg.info("Using GPU") + util.use_gpu(use_gpu) + else: + msg.info("Using CPU") - torch.set_default_tensor_type("torch.cuda.FloatTensor") - msg.info("Using GPU" if has_gpu else "Not using GPU") + msg.info(f"Loading config from: {config_path}") + config = util.load_config(config_path, create_objects=False) + util.fix_random_seed(config["training"]["seed"]) + if config["training"]["use_pytorch_for_gpu_memory"]: + use_pytorch_for_gpu_memory() - output_dir = Path(output_dir) if output_dir.exists() and [p for p in output_dir.iterdir()]: msg.warn( "Output directory is not empty", @@ -85,7 +74,10 @@ def pretrain( output_dir.mkdir() msg.good(f"Created output directory: {output_dir}") srsly.write_json(output_dir / "config.json", config) - msg.good("Saved settings to config.json") + msg.good("Saved config file in the output directory") + + config = util.load_config(config_path, create_objects=True) + pretrain_config = config["pretraining"] # Load texts from file or stdin if texts_loc != "-": # reading from a file @@ -105,49 +97,11 @@ def pretrain( with msg.loading(f"Loading model '{vectors_model}'..."): nlp = util.load_model(vectors_model) msg.good(f"Loaded model '{vectors_model}'") - pretrained_vectors = None if not use_vectors else nlp.vocab.vectors - model = create_pretraining_model( - nlp, - # TODO: replace with config - build_Tok2Vec_model( - width, - embed_rows, - conv_depth=conv_depth, - pretrained_vectors=pretrained_vectors, - bilstm_depth=bilstm_depth, # Requires PyTorch. Experimental. - subword_features=not use_chars, # Set to False for Chinese etc - maxout_pieces=cnn_pieces, # If set to 1, use Mish activation. - window_size=1, - char_embed=False, - nM=64, - nC=8, - ), - ) - # Load in pretrained weights - if init_tok2vec is not None: - components = _load_pretrained_tok2vec(nlp, init_tok2vec) - msg.text(f"Loaded pretrained tok2vec for: {components}") - # Parse the epoch number from the given weight file - model_name = re.search(r"model\d+\.bin", str(init_tok2vec)) - if model_name: - # Default weight file name so read epoch_start from it by cutting off 'model' and '.bin' - epoch_start = int(model_name.group(0)[5:][:-4]) + 1 - else: - if not epoch_start: - msg.fail( - "You have to use the --epoch-start argument when using a renamed weight file for --init-tok2vec", - exits=True, - ) - elif epoch_start < 0: - msg.fail( - f"The argument --epoch-start has to be greater or equal to 0. {epoch_start} is invalid", - exits=True, - ) - else: - # Without '--init-tok2vec' the '--epoch-start' argument is ignored - epoch_start = 0 + tok2vec = pretrain_config["model"] + model = create_pretraining_model(nlp, tok2vec) + optimizer = pretrain_config["optimizer"] - optimizer = create_default_optimizer() + epoch_start = 0 # TODO tracker = ProgressTracker(frequency=10000) msg.divider(f"Pre-training tok2vec layer - starting at epoch {epoch_start}") row_settings = {"widths": (3, 10, 10, 6, 4), "aligns": ("r", "r", "r", "r", "r")} @@ -168,28 +122,25 @@ def pretrain( file_.write(srsly.json_dumps(log) + "\n") skip_counter = 0 - for epoch in range(epoch_start, n_iter + epoch_start): - for batch_id, batch in enumerate( - util.minibatch_by_words( - (Example(doc=text) for text in texts), size=batch_size - ) - ): + loss_func = pretrain_config["loss_func"] + for epoch in range(epoch_start, pretrain_config["max_epochs"]): + examples = [Example(doc=text) for text in texts] + batches = util.minibatch_by_words(examples, size=pretrain_config["batch_size"]) + for batch_id, batch in enumerate(batches): docs, count = make_docs( nlp, - [text for (text, _) in batch], - max_length=max_length, - min_length=min_length, + [ex.doc for ex in batch], + max_length=pretrain_config["max_length"], + min_length=pretrain_config["min_length"], ) skip_counter += count - loss = make_update( - model, docs, optimizer, objective=loss_func, drop=dropout - ) + loss = make_update(model, docs, optimizer, distance=loss_func) progress = tracker.update(epoch, loss, docs) if progress: msg.row(progress, **row_settings) if texts_loc == "-" and tracker.words_per_epoch[epoch] >= 10 ** 7: break - if n_save_every and (batch_id % n_save_every == 0): + if pretrain_config["n_save_every"] and (batch_id % pretrain_config["n_save_every"] == 0): _save_model(epoch, is_temp=True) _save_model(epoch) tracker.epoch_loss = 0.0 @@ -201,17 +152,17 @@ def pretrain( msg.good("Successfully finished pretrain") -def make_update(model, docs, optimizer, drop=0.0, objective="L2"): +def make_update(model, docs, optimizer, distance): """Perform an update over a single batch of documents. docs (iterable): A batch of `Doc` objects. - drop (float): The dropout rate. optimizer (callable): An optimizer. RETURNS loss: A float for the loss. """ - predictions, backprop = model.begin_update(docs, drop=drop) - loss, gradients = get_vectors_loss(model.ops, docs, predictions, objective) - backprop(gradients, sgd=optimizer) + predictions, backprop = model.begin_update(docs) + loss, gradients = get_vectors_loss(model.ops, docs, predictions, distance) + backprop(gradients) + model.finish_update(optimizer) # Don't want to return a cupy object here # The gradients are modified in-place by the BERT MLM, # so we get an accurate loss @@ -243,12 +194,12 @@ def make_docs(nlp, batch, min_length, max_length): heads = numpy.asarray(heads, dtype="uint64") heads = heads.reshape((len(doc), 1)) doc = doc.from_array([HEAD], heads) - if len(doc) >= min_length and len(doc) < max_length: + if min_length <= len(doc) < max_length: docs.append(doc) return docs, skip_count -def get_vectors_loss(ops, docs, prediction, objective="L2"): +def get_vectors_loss(ops, docs, prediction, distance): """Compute a mean-squared error loss between the documents' vectors and the prediction. @@ -262,13 +213,6 @@ def get_vectors_loss(ops, docs, prediction, objective="L2"): # and look them up all at once. This prevents data copying. ids = ops.flatten([doc.to_array(ID).ravel() for doc in docs]) target = docs[0].vocab.vectors.data[ids] - # TODO: this code originally didn't normalize, but shouldn't normalize=True ? - if objective == "L2": - distance = L2Distance(normalize=False) - elif objective == "cosine": - distance = CosineDistance(normalize=False) - else: - raise ValueError(Errors.E142.format(loss_func=objective)) d_target, loss = distance(prediction, target) return loss, d_target @@ -281,7 +225,7 @@ def create_pretraining_model(nlp, tok2vec): """ output_size = nlp.vocab.vectors.data.shape[1] output_layer = chain( - Maxout(300, pieces=3, normalize=True, dropout=0.0), Linear(output_size) + Maxout(nO=300, nP=3, normalize=True, dropout=0.0), Linear(output_size) ) # This is annoying, but the parser etc have the flatten step after # the tok2vec. To load the weights in cleanly, we need to match @@ -289,11 +233,12 @@ def create_pretraining_model(nlp, tok2vec): # "tok2vec" has to be the same set of processes as what the components do. tok2vec = chain(tok2vec, list2array()) model = chain(tok2vec, output_layer) - model = build_masked_language_model(nlp.vocab, model) - model.set_ref("tok2vec", tok2vec) - model.set_ref("output_layer", output_layer) model.initialize(X=[nlp.make_doc("Give it a doc to infer shapes")]) - return model + mlm_model = build_masked_language_model(nlp.vocab, model) + mlm_model.set_ref("tok2vec", tok2vec) + mlm_model.set_ref("output_layer", output_layer) + mlm_model.initialize(X=[nlp.make_doc("Give it a doc to infer shapes")]) + return mlm_model class ProgressTracker(object): diff --git a/spacy/errors.py b/spacy/errors.py index 852c55225..96b323ef5 100644 --- a/spacy/errors.py +++ b/spacy/errors.py @@ -441,8 +441,6 @@ class Errors(object): "should be of equal length.") E141 = ("Entity vectors should be of length {required} instead of the " "provided {found}.") - E142 = ("Unsupported loss_function '{loss_func}'. Use either 'L2' or " - "'cosine'.") E143 = ("Labels for component '{name}' not initialized. Did you forget to " "call add_label()?") E144 = ("Could not find parameter `{param}` when building the entity " diff --git a/spacy/ml/models/multi_task.py b/spacy/ml/models/multi_task.py index 1c193df82..970d31899 100644 --- a/spacy/ml/models/multi_task.py +++ b/spacy/ml/models/multi_task.py @@ -1,4 +1,6 @@ -from thinc.api import chain, Maxout, LayerNorm, Softmax, Linear, zero_init +import numpy + +from thinc.api import chain, Maxout, LayerNorm, Softmax, Linear, zero_init, Model def build_multi_task_model(n_tags, tok2vec=None, token_vector_width=96): @@ -24,6 +26,80 @@ def build_cloze_multi_task_model(vocab, tok2vec): return model -def build_masked_language_model(*args, **kwargs): - # TODO cf https://github.com/explosion/spaCy/blob/2c107f02a4d60bda2440db0aad1a88cbbf4fb52d/spacy/_ml.py#L828 - raise NotImplementedError +def build_masked_language_model(vocab, wrapped_model, mask_prob=0.15): + """Convert a model into a BERT-style masked language model""" + + random_words = _RandomWords(vocab) + + def mlm_forward(model, docs, is_train): + mask, docs = _apply_mask(docs, random_words, mask_prob=mask_prob) + mask = model.ops.asarray(mask).reshape((mask.shape[0], 1)) + output, backprop = model.get_ref("wrapped-model").begin_update(docs) # drop=drop + + def mlm_backward(d_output): + d_output *= 1 - mask + return backprop(d_output) + + return output, mlm_backward + + mlm_model = Model("masked-language-model", mlm_forward, layers=[wrapped_model]) + mlm_model.set_ref("wrapped-model", wrapped_model) + + return mlm_model + + +class _RandomWords(object): + def __init__(self, vocab): + self.words = [lex.text for lex in vocab if lex.prob != 0.0] + self.probs = [lex.prob for lex in vocab if lex.prob != 0.0] + self.words = self.words[:10000] + self.probs = self.probs[:10000] + self.probs = numpy.exp(numpy.array(self.probs, dtype="f")) + self.probs /= self.probs.sum() + self._cache = [] + + def next(self): + if not self._cache: + self._cache.extend( + numpy.random.choice(len(self.words), 10000, p=self.probs) + ) + index = self._cache.pop() + return self.words[index] + + +def _apply_mask(docs, random_words, mask_prob=0.15): + # This needs to be here to avoid circular imports + from ...tokens import Doc + + N = sum(len(doc) for doc in docs) + mask = numpy.random.uniform(0.0, 1.0, (N,)) + mask = mask >= mask_prob + i = 0 + masked_docs = [] + for doc in docs: + words = [] + for token in doc: + if not mask[i]: + word = _replace_word(token.text, random_words) + else: + word = token.text + words.append(word) + i += 1 + spaces = [bool(w.whitespace_) for w in doc] + # NB: If you change this implementation to instead modify + # the docs in place, take care that the IDs reflect the original + # words. Currently we use the original docs to make the vectors + # for the target, so we don't lose the original tokens. But if + # you modified the docs in place here, you would. + masked_docs.append(Doc(doc.vocab, words=words, spaces=spaces)) + return mask, masked_docs + + +def _replace_word(word, random_words, mask="[MASK]"): + roll = numpy.random.random() + if roll < 0.8: + return mask + elif roll < 0.9: + return random_words.next() + else: + return word \ No newline at end of file