Refactor pretrain and support character-based objective for v3 (#5706)

* Start adding character-based stuff

* Start adding character-based objective

* Start adding character-based stuff

* Start adding character-based objective

* Remove outdated comment

* Update pretraining models

* Add/fix character-based multi-task models

* Refactor pretrain and support character-based objective

* Update pretrain config

* Remove unused

* Fix flake8 errors

* Clean up imports

* Format

* Format

* Update Thinc version

* Raise error if vectors objective but no vectors
Matthew Honnibal 2020-07-03 17:57:28 +02:00 committed by GitHub
parent cdf9ee1716
commit 2bd1bf81f1
6 changed files with 198 additions and 98 deletions

View File

@@ -54,6 +54,10 @@ seed = ${training:seed}
use_pytorch_for_gpu_memory = ${training:use_pytorch_for_gpu_memory}
tok2vec_model = "nlp.pipeline.tok2vec.model"
[pretraining.objective]
type = "characters"
n_characters = 4
[pretraining.optimizer]
@optimizers = "Adam.v1"
beta1 = 0.9
@@ -65,10 +69,6 @@ use_averages = true
eps = 1e-8
learn_rate = 0.001
[pretraining.loss_func]
@losses = "CosineDistance.v1"
normalize = true
[nlp]
lang = "en"
vectors = null
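
A rough usage sketch (not part of the commit): assuming the config above is saved as pretrain.cfg (an illustrative filename), the new [pretraining.objective] block replaces the old [pretraining.loss_func] section and is turned into a loss callable by the create_objective helper added in spacy/cli/pretrain.py below.

from pathlib import Path
from spacy import util
from spacy.cli.pretrain import create_objective  # added by this commit

# Load the config without instantiating registered objects, as pretrain() itself does.
config = util.load_config(Path("pretrain.cfg"), create_objects=False)
objective_cfg = config["pretraining"]["objective"]  # e.g. {"type": "characters", "n_characters": 4}
# create_objective returns a callable with the signature (ops, docs, prediction) -> (loss, d_target)
objective = create_objective(objective_cfg)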

View File: pyproject.toml

@@ -6,7 +6,7 @@ requires = [
"cymem>=2.0.2,<2.1.0",
"preshed>=3.0.2,<3.1.0",
"murmurhash>=0.28.0,<1.1.0",
"thinc==8.0.0a11",
"thinc>=8.0.0a12,<8.0.0a20",
"blis>=0.4.0,<0.5.0"
]
build-backend = "setuptools.build_meta"

View File: requirements.txt

@@ -1,7 +1,7 @@
# Our libraries
cymem>=2.0.2,<2.1.0
preshed>=3.0.2,<3.1.0
thinc==8.0.0a11
thinc>=8.0.0a12,<8.0.0a20
blis>=0.4.0,<0.5.0
ml_datasets>=0.1.1
murmurhash>=0.28.0,<1.1.0

View File: setup.cfg

@@ -34,13 +34,13 @@ setup_requires =
cymem>=2.0.2,<2.1.0
preshed>=3.0.2,<3.1.0
murmurhash>=0.28.0,<1.1.0
thinc==8.0.0a11
thinc>=8.0.0a12,<8.0.0a20
install_requires =
# Our libraries
murmurhash>=0.28.0,<1.1.0
cymem>=2.0.2,<2.1.0
preshed>=3.0.2,<3.1.0
thinc==8.0.0a11
thinc>=8.0.0a11,<8.0.0a20
blis>=0.4.0,<0.5.0
wasabi>=0.7.0,<1.1.0
srsly>=2.1.0,<3.0.0

View File: spacy/cli/pretrain.py

@@ -5,13 +5,17 @@ import time
import re
from collections import Counter
from pathlib import Path
from thinc.api import Linear, Maxout, chain, list2array, use_pytorch_for_gpu_memory
from thinc.api import use_pytorch_for_gpu_memory
from thinc.api import set_dropout_rate, to_categorical
from thinc.api import CosineDistance, L2Distance
from wasabi import msg
import srsly
from functools import partial
from ._app import app, Arg, Opt
from ..errors import Errors
from ..ml.models.multi_task import build_masked_language_model
from ..ml.models.multi_task import build_cloze_multi_task_model
from ..ml.models.multi_task import build_cloze_characters_multi_task_model
from ..tokens import Doc
from ..attrs import ID, HEAD
from .. import util
@@ -21,7 +25,6 @@ from .. import util
def pretrain_cli(
# fmt: off
texts_loc: Path = Arg(..., help="Path to JSONL file with raw texts to learn from, with text provided as the key 'text' or tokens as the key 'tokens'", exists=True),
vectors_model: str = Arg(..., help="Name or path to spaCy model with vectors to learn from"),
output_dir: Path = Arg(..., help="Directory to write models to on each epoch"),
config_path: Path = Arg(..., help="Path to config file", exists=True, dir_okay=False),
use_gpu: int = Opt(-1, "--use-gpu", "-g", help="Use GPU"),
@@ -31,11 +34,15 @@ def pretrain_cli(
):
"""
Pre-train the 'token-to-vector' (tok2vec) layer of pipeline components,
using an approximate language-modelling objective. Specifically, we load
pretrained vectors, and train a component like a CNN, BiLSTM, etc to predict
vectors which match the pretrained ones. The weights are saved to a directory
after each epoch. You can then pass a path to one of these pretrained weights
files to the 'spacy train' command.
using an approximate language-modelling objective. Two objective types
are available, vector-based and character-based.
In the vector-based objective, we load word vectors that have been trained
using a word2vec-style distributional similarity algorithm, and train a
component like a CNN, BiLSTM, etc to predict vectors which match the
pretrained ones. The weights are saved to a directory after each epoch. You
can then pass a path to one of these pretrained weights files to the
'spacy train' command.
This technique may be especially helpful if you have little labelled data.
However, it's still quite experimental, so your mileage may vary.
@@ -46,7 +53,6 @@ def pretrain_cli(
"""
pretrain(
texts_loc,
vectors_model,
output_dir,
config_path,
use_gpu=use_gpu,
@@ -57,15 +63,16 @@ def pretrain(
def pretrain(
texts_loc: Path,
vectors_model: str,
output_dir: Path,
config_path: Path,
use_gpu: int = -1,
resume_path: Optional[Path] = None,
epoch_resume: Optional[int] = None,
):
if not config_path or not config_path.exists():
msg.fail("Config file not found", config_path, exits=1)
verify_cli_args(**locals())
if not output_dir.exists():
output_dir.mkdir()
msg.good(f"Created output directory: {output_dir}")
if use_gpu >= 0:
msg.info("Using GPU")
@@ -76,82 +83,35 @@ def pretrain(
msg.info(f"Loading config from: {config_path}")
config = util.load_config(config_path, create_objects=False)
util.fix_random_seed(config["pretraining"]["seed"])
if config["pretraining"]["use_pytorch_for_gpu_memory"]:
if use_gpu >= 0 and config["pretraining"]["use_pytorch_for_gpu_memory"]:
use_pytorch_for_gpu_memory()
if output_dir.exists() and [p for p in output_dir.iterdir()]:
if resume_path:
msg.warn(
"Output directory is not empty. ",
"If you're resuming a run from a previous model in this directory, "
"the old models for the consecutive epochs will be overwritten "
"with the new ones.",
)
else:
msg.warn(
"Output directory is not empty. ",
"It is better to use an empty directory or refer to a new output path, "
"then the new directory will be created for you.",
)
if not output_dir.exists():
output_dir.mkdir()
msg.good(f"Created output directory: {output_dir}")
nlp_config = config["nlp"]
srsly.write_json(output_dir / "config.json", config)
msg.good("Saved config file in the output directory")
config = util.load_config(config_path, create_objects=True)
nlp = util.load_model_from_config(nlp_config)
pretrain_config = config["pretraining"]
# Load texts from file or stdin
if texts_loc != "-": # reading from a file
texts_loc = Path(texts_loc)
if not texts_loc.exists():
msg.fail("Input text file doesn't exist", texts_loc, exits=1)
with msg.loading("Loading input texts..."):
texts = list(srsly.read_jsonl(texts_loc))
if not texts:
msg.fail("Input file is empty", texts_loc, exits=1)
msg.good("Loaded input texts")
random.shuffle(texts)
else: # reading from stdin
msg.info("Reading input text from stdin...")
texts = srsly.read_jsonl("-")
with msg.loading(f"Loading model '{vectors_model}'..."):
nlp = util.load_model(vectors_model)
msg.good(f"Loaded model '{vectors_model}'")
tok2vec_path = pretrain_config["tok2vec_model"]
tok2vec = config
for subpath in tok2vec_path.split("."):
tok2vec = tok2vec.get(subpath)
model = create_pretraining_model(nlp, tok2vec)
model = create_pretraining_model(nlp, tok2vec, pretrain_config)
optimizer = pretrain_config["optimizer"]
# Load in pretrained weights to resume from
if resume_path is not None:
msg.info(f"Resume training tok2vec from: {resume_path}")
with resume_path.open("rb") as file_:
weights_data = file_.read()
model.get_ref("tok2vec").from_bytes(weights_data)
# Parse the epoch number from the given weight file
model_name = re.search(r"model\d+\.bin", str(resume_path))
if model_name:
# Default weight file name so read epoch_start from it by cutting off 'model' and '.bin'
epoch_resume = int(model_name.group(0)[5:][:-4]) + 1
msg.info(f"Resuming from epoch: {epoch_resume}")
else:
if not epoch_resume:
msg.fail(
"You have to use the --epoch-resume setting when using a renamed weight file for --resume-path",
exits=True,
)
elif epoch_resume < 0:
msg.fail(
f"The argument --epoch-resume has to be greater or equal to 0. {epoch_resume} is invalid",
exits=True,
)
else:
msg.info(f"Resuming from epoch: {epoch_resume}")
_resume_model(model, resume_path, epoch_resume)
else:
# Without '--resume-path' the '--epoch-resume' argument is ignored
epoch_resume = 0
@@ -176,7 +136,7 @@ def pretrain(
file_.write(srsly.json_dumps(log) + "\n")
skip_counter = 0
loss_func = pretrain_config["loss_func"]
objective = create_objective(pretrain_config["objective"])
for epoch in range(epoch_resume, pretrain_config["max_epochs"]):
batches = util.minibatch_by_words(texts, size=pretrain_config["batch_size"])
for batch_id, batch in enumerate(batches):
@@ -187,7 +147,7 @@
min_length=pretrain_config["min_length"],
)
skip_counter += count
loss = make_update(model, docs, optimizer, distance=loss_func)
loss = make_update(model, docs, optimizer, objective)
progress = tracker.update(epoch, loss, docs)
if progress:
msg.row(progress, **row_settings)
@@ -207,7 +167,22 @@ def pretrain(
msg.good("Successfully finished pretrain")
def make_update(model, docs, optimizer, distance):
def _resume_model(model, resume_path, epoch_resume):
msg.info(f"Resume training tok2vec from: {resume_path}")
with resume_path.open("rb") as file_:
weights_data = file_.read()
model.get_ref("tok2vec").from_bytes(weights_data)
# Parse the epoch number from the given weight file
model_name = re.search(r"model\d+\.bin", str(resume_path))
if model_name:
# Default weight file name so read epoch_start from it by cutting off 'model' and '.bin'
epoch_resume = int(model_name.group(0)[5:][:-4]) + 1
msg.info(f"Resuming from epoch: {epoch_resume}")
else:
msg.info(f"Resuming from epoch: {epoch_resume}")
def make_update(model, docs, optimizer, objective_func):
"""Perform an update over a single batch of documents.
docs (iterable): A batch of `Doc` objects.
@@ -215,7 +190,7 @@ def make_update(model, docs, optimizer, distance):
RETURNS loss: A float for the loss.
"""
predictions, backprop = model.begin_update(docs)
loss, gradients = get_vectors_loss(model.ops, docs, predictions, distance)
loss, gradients = objective_func(model.ops, docs, predictions)
backprop(gradients)
model.finish_update(optimizer)
# Don't want to return a cupy object here
@@ -254,13 +229,38 @@ def make_docs(nlp, batch, min_length, max_length):
return docs, skip_count
def get_vectors_loss(ops, docs, prediction, distance):
"""Compute a mean-squared error loss between the documents' vectors and
the prediction.
def create_objective(config):
"""Create the objective for pretraining.
We'd like to replace this with a registry function but it's tricky because
we're also making a model choice based on this. For now we hard-code support
for two types (characters, vectors). For characters you can specify
n_characters, for vectors you can specify the loss.
Bleh.
"""
objective_type = config["type"]
if objective_type == "characters":
return partial(get_characters_loss, nr_char=config["n_characters"])
elif objective_type == "vectors":
if config["loss"] == "cosine":
return partial(
get_vectors_loss,
distance=CosineDistance(normalize=True, ignore_zeros=True),
)
elif config["loss"] == "L2":
return partial(
get_vectors_loss, distance=L2Distance(normalize=True, ignore_zeros=True)
)
else:
raise ValueError("Unexpected loss type", config["loss"])
else:
raise ValueError("Unexpected objective_type", objective_type)
Note that this is ripe for customization! We could compute the vectors
in some other word, e.g. with an LSTM language model, or use some other
type of objective.
def get_vectors_loss(ops, docs, prediction, distance):
"""Compute a loss based on a distance between the documents' vectors and
the prediction.
"""
# The simplest way to implement this would be to vstack the
# token.vector values, but that's a bit inefficient, especially on GPU.
@@ -272,7 +272,19 @@ def get_vectors_loss(ops, docs, prediction, distance):
return loss, d_target
def create_pretraining_model(nlp, tok2vec):
def get_characters_loss(ops, docs, prediction, nr_char):
"""Compute a loss based on a number of characters predicted from the docs."""
target_ids = numpy.vstack([doc.to_utf8_array(nr_char=nr_char) for doc in docs])
target_ids = target_ids.reshape((-1,))
target = ops.asarray(to_categorical(target_ids, n_classes=256), dtype="f")
target = target.reshape((-1, 256 * nr_char))
diff = prediction - target
loss = (diff ** 2).sum()
d_target = diff / float(prediction.shape[0])
return loss, d_target
def create_pretraining_model(nlp, tok2vec, pretrain_config):
"""Define a network for the pretraining. We simply add an output layer onto
the tok2vec input model. The tok2vec input model needs to be a model that
takes a batch of Doc objects (as a list), and returns a list of arrays.
@@ -280,18 +292,24 @@ def create_pretraining_model(nlp, tok2vec):
The actual tok2vec layer is stored as a reference, and only this bit will be
serialized to file and read back in when calling the 'train' command.
"""
output_size = nlp.vocab.vectors.data.shape[1]
output_layer = chain(
Maxout(nO=300, nP=3, normalize=True, dropout=0.0), Linear(output_size)
)
model = chain(tok2vec, list2array())
model = chain(model, output_layer)
# TODO
maxout_pieces = 3
hidden_size = 300
if pretrain_config["objective"]["type"] == "vectors":
model = build_cloze_multi_task_model(
nlp.vocab, tok2vec, hidden_size=hidden_size, maxout_pieces=maxout_pieces
)
elif pretrain_config["objective"]["type"] == "characters":
model = build_cloze_characters_multi_task_model(
nlp.vocab,
tok2vec,
hidden_size=hidden_size,
maxout_pieces=maxout_pieces,
nr_char=pretrain_config["objective"]["n_characters"],
)
model.initialize(X=[nlp.make_doc("Give it a doc to infer shapes")])
mlm_model = build_masked_language_model(nlp.vocab, model)
mlm_model.set_ref("tok2vec", tok2vec)
mlm_model.set_ref("output_layer", output_layer)
mlm_model.initialize(X=[nlp.make_doc("Give it a doc to infer shapes")])
return mlm_model
set_dropout_rate(model, pretrain_config["dropout"])
return model
class ProgressTracker(object):
@@ -340,3 +358,53 @@ def _smart_round(figure, width=10, max_decimal=4):
n_decimal = min(n_decimal, max_decimal)
format_str = "%." + str(n_decimal) + "f"
return format_str % figure
def verify_cli_args(
texts_loc, output_dir, config_path, use_gpu, resume_path, epoch_resume
):
if not config_path or not config_path.exists():
msg.fail("Config file not found", config_path, exits=1)
if output_dir.exists() and [p for p in output_dir.iterdir()]:
if resume_path:
msg.warn(
"Output directory is not empty. ",
"If you're resuming a run from a previous model in this directory, "
"the old models for the consecutive epochs will be overwritten "
"with the new ones.",
)
else:
msg.warn(
"Output directory is not empty. ",
"It is better to use an empty directory or refer to a new output path, "
"then the new directory will be created for you.",
)
if texts_loc != "-": # reading from a file
texts_loc = Path(texts_loc)
if not texts_loc.exists():
msg.fail("Input text file doesn't exist", texts_loc, exits=1)
for text in srsly.read_jsonl(texts_loc):
break
else:
msg.fail("Input file is empty", texts_loc, exits=1)
if resume_path is not None:
model_name = re.search(r"model\d+\.bin", str(resume_path))
if not model_name and not epoch_resume:
msg.fail(
"You have to use the --epoch-resume setting when using a renamed weight file for --resume-path",
exits=True,
)
elif not model_name and epoch_resume < 0:
msg.fail(
f"The argument --epoch-resume has to be greater or equal to 0. {epoch_resume} is invalid",
exits=True,
)
config = util.load_config(config_path, create_objects=False)
if config["pretraining"]["objective"]["type"] == "vectors":
if not config["nlp"]["vectors"]:
msg.fail(
"Must specify nlp.vectors if pretraining.objective.type is vectors",
exits=True
)
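
For intuition, here is a minimal toy rendering of the character objective's loss, mirroring get_characters_loss() above; the byte ids, shapes and values are invented for illustration and this snippet is not part of the commit.

import numpy
from thinc.api import get_current_ops, to_categorical

ops = get_current_ops()
nr_char = 4
# Two "tokens", each reduced to 4 UTF-8 byte ids (0-255), as Doc.to_utf8_array(nr_char=4) would give.
target_ids = numpy.asarray([[72, 101, 108, 108], [87, 111, 114, 100]], dtype="i").reshape((-1,))
target = ops.asarray(to_categorical(target_ids, n_classes=256), dtype="f")
target = target.reshape((-1, 256 * nr_char))  # shape (2, 1024): one concatenated one-hot block per token
prediction = ops.asarray(numpy.zeros((2, 256 * nr_char), dtype="f"))  # stand-in for the model output
diff = prediction - target
loss = (diff ** 2).sum()  # 8.0 here: one unit of squared error per character slot
d_target = diff / float(prediction.shape[0])  # gradient handed back to the tok2vec layer
print(loss, d_target.shape)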

View File: spacy/ml/models/multi_task.py

@@ -1,6 +1,7 @@
import numpy
from thinc.api import chain, Maxout, LayerNorm, Softmax, Linear, zero_init, Model
from thinc.api import MultiSoftmax, list2array
def build_multi_task_model(tok2vec, maxout_pieces, token_vector_width, nO=None):
@@ -21,9 +22,10 @@ def build_multi_task_model(tok2vec, maxout_pieces, token_vector_width, nO=None):
return model
def build_cloze_multi_task_model(vocab, tok2vec, maxout_pieces, nO=None):
def build_cloze_multi_task_model(vocab, tok2vec, maxout_pieces, hidden_size, nO=None):
# nO = vocab.vectors.data.shape[1]
output_layer = chain(
list2array(),
Maxout(
nO=nO,
nI=tok2vec.get_dim("nO"),
@@ -40,6 +42,22 @@ def build_cloze_multi_task_model(vocab, tok2vec, maxout_pieces, nO=None):
return model
def build_cloze_characters_multi_task_model(
vocab, tok2vec, maxout_pieces, hidden_size, nr_char
):
output_layer = chain(
list2array(),
Maxout(hidden_size, nP=maxout_pieces),
LayerNorm(nI=hidden_size),
MultiSoftmax([256] * nr_char, nI=hidden_size),
)
model = build_masked_language_model(vocab, chain(tok2vec, output_layer))
model.set_ref("tok2vec", tok2vec)
model.set_ref("output_layer", output_layer)
return model
def build_masked_language_model(vocab, wrapped_model, mask_prob=0.15):
"""Convert a model into a BERT-style masked language model"""
@@ -48,7 +66,7 @@ def build_masked_language_model(vocab, wrapped_model, mask_prob=0.15):
def mlm_forward(model, docs, is_train):
mask, docs = _apply_mask(docs, random_words, mask_prob=mask_prob)
mask = model.ops.asarray(mask).reshape((mask.shape[0], 1))
output, backprop = model.get_ref("wrapped-model").begin_update(docs)
output, backprop = model.layers[0](docs, is_train)
def mlm_backward(d_output):
d_output *= 1 - mask
@@ -56,8 +74,22 @@ def build_masked_language_model(vocab, wrapped_model, mask_prob=0.15):
return output, mlm_backward
mlm_model = Model("masked-language-model", mlm_forward, layers=[wrapped_model])
mlm_model.set_ref("wrapped-model", wrapped_model)
def mlm_initialize(model, X=None, Y=None):
wrapped = model.layers[0]
wrapped.initialize(X=X, Y=Y)
for dim in wrapped.dim_names:
if wrapped.has_dim(dim):
model.set_dim(dim, wrapped.get_dim(dim))
mlm_model = Model(
"masked-language-model",
mlm_forward,
layers=[wrapped_model],
init=mlm_initialize,
refs={"wrapped": wrapped_model},
dims={dim: None for dim in wrapped_model.dim_names},
)
mlm_model.set_ref("wrapped", wrapped_model)
return mlm_model
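
The mlm_initialize/refs/dims plumbing above follows a general Thinc wrapper pattern: a Model that delegates to a single wrapped layer and copies that layer's dimensions onto itself during init. A self-contained sketch of the pattern with a plain Linear layer (illustrative names only, not part of the commit):

from thinc.api import Linear, Model

def wrapper_forward(model, X, is_train):
    # Delegate straight to the wrapped layer; calling it returns (Y, backprop).
    return model.layers[0](X, is_train)

def wrapper_init(model, X=None, Y=None):
    wrapped = model.layers[0]
    wrapped.initialize(X=X, Y=Y)
    # Mirror the wrapped layer's dims so callers can query the wrapper directly.
    for dim in wrapped.dim_names:
        if wrapped.has_dim(dim):
            model.set_dim(dim, wrapped.get_dim(dim))

inner = Linear(nO=4, nI=3)
wrapper = Model(
    "wrapper",
    wrapper_forward,
    layers=[inner],
    init=wrapper_init,
    refs={"wrapped": inner},
    dims={dim: None for dim in inner.dim_names},
)
wrapper.initialize()
assert wrapper.get_dim("nO") == 4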