mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-13 05:07:03 +03:00
Refactor pretrain and support character-based objective for v3 (#5706)
* Start adding character-based stuff * Start adding character-based objective * Start adding character-based stuff * Start adding character-based objective * Remove outdated comment * Update pretraining models * Add/fix character-based multi-task models * Refactor pretrain and support character-based objective * Update pretrain config * Remove unused * Fix flake8 errors * Clean up imports * Format * Format * Update Thinc version * Raise error if vectors objective but no vectors
This commit is contained in:
parent
cdf9ee1716
commit
2bd1bf81f1
|
@ -54,6 +54,10 @@ seed = ${training:seed}
|
||||||
use_pytorch_for_gpu_memory = ${training:use_pytorch_for_gpu_memory}
|
use_pytorch_for_gpu_memory = ${training:use_pytorch_for_gpu_memory}
|
||||||
tok2vec_model = "nlp.pipeline.tok2vec.model"
|
tok2vec_model = "nlp.pipeline.tok2vec.model"
|
||||||
|
|
||||||
|
[pretraining.objective]
|
||||||
|
type = "characters"
|
||||||
|
n_characters = 4
|
||||||
|
|
||||||
[pretraining.optimizer]
|
[pretraining.optimizer]
|
||||||
@optimizers = "Adam.v1"
|
@optimizers = "Adam.v1"
|
||||||
beta1 = 0.9
|
beta1 = 0.9
|
||||||
|
@ -65,10 +69,6 @@ use_averages = true
|
||||||
eps = 1e-8
|
eps = 1e-8
|
||||||
learn_rate = 0.001
|
learn_rate = 0.001
|
||||||
|
|
||||||
[pretraining.loss_func]
|
|
||||||
@losses = "CosineDistance.v1"
|
|
||||||
normalize = true
|
|
||||||
|
|
||||||
[nlp]
|
[nlp]
|
||||||
lang = "en"
|
lang = "en"
|
||||||
vectors = null
|
vectors = null
|
||||||
|
|
|
@ -6,7 +6,7 @@ requires = [
|
||||||
"cymem>=2.0.2,<2.1.0",
|
"cymem>=2.0.2,<2.1.0",
|
||||||
"preshed>=3.0.2,<3.1.0",
|
"preshed>=3.0.2,<3.1.0",
|
||||||
"murmurhash>=0.28.0,<1.1.0",
|
"murmurhash>=0.28.0,<1.1.0",
|
||||||
"thinc==8.0.0a11",
|
"thinc>=8.0.0a12,<8.0.0a20",
|
||||||
"blis>=0.4.0,<0.5.0"
|
"blis>=0.4.0,<0.5.0"
|
||||||
]
|
]
|
||||||
build-backend = "setuptools.build_meta"
|
build-backend = "setuptools.build_meta"
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
# Our libraries
|
# Our libraries
|
||||||
cymem>=2.0.2,<2.1.0
|
cymem>=2.0.2,<2.1.0
|
||||||
preshed>=3.0.2,<3.1.0
|
preshed>=3.0.2,<3.1.0
|
||||||
thinc==8.0.0a11
|
thinc>=8.0.0a12,<8.0.0a20
|
||||||
blis>=0.4.0,<0.5.0
|
blis>=0.4.0,<0.5.0
|
||||||
ml_datasets>=0.1.1
|
ml_datasets>=0.1.1
|
||||||
murmurhash>=0.28.0,<1.1.0
|
murmurhash>=0.28.0,<1.1.0
|
||||||
|
|
|
@ -34,13 +34,13 @@ setup_requires =
|
||||||
cymem>=2.0.2,<2.1.0
|
cymem>=2.0.2,<2.1.0
|
||||||
preshed>=3.0.2,<3.1.0
|
preshed>=3.0.2,<3.1.0
|
||||||
murmurhash>=0.28.0,<1.1.0
|
murmurhash>=0.28.0,<1.1.0
|
||||||
thinc==8.0.0a11
|
thinc>=8.0.0a12,<8.0.0a20
|
||||||
install_requires =
|
install_requires =
|
||||||
# Our libraries
|
# Our libraries
|
||||||
murmurhash>=0.28.0,<1.1.0
|
murmurhash>=0.28.0,<1.1.0
|
||||||
cymem>=2.0.2,<2.1.0
|
cymem>=2.0.2,<2.1.0
|
||||||
preshed>=3.0.2,<3.1.0
|
preshed>=3.0.2,<3.1.0
|
||||||
thinc==8.0.0a11
|
thinc>=8.0.0a11,<8.0.0a20
|
||||||
blis>=0.4.0,<0.5.0
|
blis>=0.4.0,<0.5.0
|
||||||
wasabi>=0.7.0,<1.1.0
|
wasabi>=0.7.0,<1.1.0
|
||||||
srsly>=2.1.0,<3.0.0
|
srsly>=2.1.0,<3.0.0
|
||||||
|
|
|
@ -5,13 +5,17 @@ import time
|
||||||
import re
|
import re
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from thinc.api import Linear, Maxout, chain, list2array, use_pytorch_for_gpu_memory
|
from thinc.api import use_pytorch_for_gpu_memory
|
||||||
|
from thinc.api import set_dropout_rate, to_categorical
|
||||||
|
from thinc.api import CosineDistance, L2Distance
|
||||||
from wasabi import msg
|
from wasabi import msg
|
||||||
import srsly
|
import srsly
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
from ._app import app, Arg, Opt
|
from ._app import app, Arg, Opt
|
||||||
from ..errors import Errors
|
from ..errors import Errors
|
||||||
from ..ml.models.multi_task import build_masked_language_model
|
from ..ml.models.multi_task import build_cloze_multi_task_model
|
||||||
|
from ..ml.models.multi_task import build_cloze_characters_multi_task_model
|
||||||
from ..tokens import Doc
|
from ..tokens import Doc
|
||||||
from ..attrs import ID, HEAD
|
from ..attrs import ID, HEAD
|
||||||
from .. import util
|
from .. import util
|
||||||
|
@ -21,7 +25,6 @@ from .. import util
|
||||||
def pretrain_cli(
|
def pretrain_cli(
|
||||||
# fmt: off
|
# fmt: off
|
||||||
texts_loc: Path = Arg(..., help="Path to JSONL file with raw texts to learn from, with text provided as the key 'text' or tokens as the key 'tokens'", exists=True),
|
texts_loc: Path = Arg(..., help="Path to JSONL file with raw texts to learn from, with text provided as the key 'text' or tokens as the key 'tokens'", exists=True),
|
||||||
vectors_model: str = Arg(..., help="Name or path to spaCy model with vectors to learn from"),
|
|
||||||
output_dir: Path = Arg(..., help="Directory to write models to on each epoch"),
|
output_dir: Path = Arg(..., help="Directory to write models to on each epoch"),
|
||||||
config_path: Path = Arg(..., help="Path to config file", exists=True, dir_okay=False),
|
config_path: Path = Arg(..., help="Path to config file", exists=True, dir_okay=False),
|
||||||
use_gpu: int = Opt(-1, "--use-gpu", "-g", help="Use GPU"),
|
use_gpu: int = Opt(-1, "--use-gpu", "-g", help="Use GPU"),
|
||||||
|
@ -31,11 +34,15 @@ def pretrain_cli(
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Pre-train the 'token-to-vector' (tok2vec) layer of pipeline components,
|
Pre-train the 'token-to-vector' (tok2vec) layer of pipeline components,
|
||||||
using an approximate language-modelling objective. Specifically, we load
|
using an approximate language-modelling objective. Two objective types
|
||||||
pretrained vectors, and train a component like a CNN, BiLSTM, etc to predict
|
are available, vector-based and character-based.
|
||||||
vectors which match the pretrained ones. The weights are saved to a directory
|
|
||||||
after each epoch. You can then pass a path to one of these pretrained weights
|
In the vector-based objective, we load word vectors that have been trained
|
||||||
files to the 'spacy train' command.
|
using a word2vec-style distributional similarity algorithm, and train a
|
||||||
|
component like a CNN, BiLSTM, etc to predict vectors which match the
|
||||||
|
pretrained ones. The weights are saved to a directory after each epoch. You
|
||||||
|
can then pass a path to one of these pretrained weights files to the
|
||||||
|
'spacy train' command.
|
||||||
|
|
||||||
This technique may be especially helpful if you have little labelled data.
|
This technique may be especially helpful if you have little labelled data.
|
||||||
However, it's still quite experimental, so your mileage may vary.
|
However, it's still quite experimental, so your mileage may vary.
|
||||||
|
@ -46,7 +53,6 @@ def pretrain_cli(
|
||||||
"""
|
"""
|
||||||
pretrain(
|
pretrain(
|
||||||
texts_loc,
|
texts_loc,
|
||||||
vectors_model,
|
|
||||||
output_dir,
|
output_dir,
|
||||||
config_path,
|
config_path,
|
||||||
use_gpu=use_gpu,
|
use_gpu=use_gpu,
|
||||||
|
@ -57,15 +63,16 @@ def pretrain_cli(
|
||||||
|
|
||||||
def pretrain(
|
def pretrain(
|
||||||
texts_loc: Path,
|
texts_loc: Path,
|
||||||
vectors_model: str,
|
|
||||||
output_dir: Path,
|
output_dir: Path,
|
||||||
config_path: Path,
|
config_path: Path,
|
||||||
use_gpu: int = -1,
|
use_gpu: int = -1,
|
||||||
resume_path: Optional[Path] = None,
|
resume_path: Optional[Path] = None,
|
||||||
epoch_resume: Optional[int] = None,
|
epoch_resume: Optional[int] = None,
|
||||||
):
|
):
|
||||||
if not config_path or not config_path.exists():
|
verify_cli_args(**locals())
|
||||||
msg.fail("Config file not found", config_path, exits=1)
|
if not output_dir.exists():
|
||||||
|
output_dir.mkdir()
|
||||||
|
msg.good(f"Created output directory: {output_dir}")
|
||||||
|
|
||||||
if use_gpu >= 0:
|
if use_gpu >= 0:
|
||||||
msg.info("Using GPU")
|
msg.info("Using GPU")
|
||||||
|
@ -76,82 +83,35 @@ def pretrain(
|
||||||
msg.info(f"Loading config from: {config_path}")
|
msg.info(f"Loading config from: {config_path}")
|
||||||
config = util.load_config(config_path, create_objects=False)
|
config = util.load_config(config_path, create_objects=False)
|
||||||
util.fix_random_seed(config["pretraining"]["seed"])
|
util.fix_random_seed(config["pretraining"]["seed"])
|
||||||
if config["pretraining"]["use_pytorch_for_gpu_memory"]:
|
if use_gpu >= 0 and config["pretraining"]["use_pytorch_for_gpu_memory"]:
|
||||||
use_pytorch_for_gpu_memory()
|
use_pytorch_for_gpu_memory()
|
||||||
|
|
||||||
if output_dir.exists() and [p for p in output_dir.iterdir()]:
|
nlp_config = config["nlp"]
|
||||||
if resume_path:
|
|
||||||
msg.warn(
|
|
||||||
"Output directory is not empty. ",
|
|
||||||
"If you're resuming a run from a previous model in this directory, "
|
|
||||||
"the old models for the consecutive epochs will be overwritten "
|
|
||||||
"with the new ones.",
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
msg.warn(
|
|
||||||
"Output directory is not empty. ",
|
|
||||||
"It is better to use an empty directory or refer to a new output path, "
|
|
||||||
"then the new directory will be created for you.",
|
|
||||||
)
|
|
||||||
if not output_dir.exists():
|
|
||||||
output_dir.mkdir()
|
|
||||||
msg.good(f"Created output directory: {output_dir}")
|
|
||||||
srsly.write_json(output_dir / "config.json", config)
|
srsly.write_json(output_dir / "config.json", config)
|
||||||
msg.good("Saved config file in the output directory")
|
msg.good("Saved config file in the output directory")
|
||||||
|
|
||||||
config = util.load_config(config_path, create_objects=True)
|
config = util.load_config(config_path, create_objects=True)
|
||||||
|
nlp = util.load_model_from_config(nlp_config)
|
||||||
pretrain_config = config["pretraining"]
|
pretrain_config = config["pretraining"]
|
||||||
|
|
||||||
# Load texts from file or stdin
|
|
||||||
if texts_loc != "-": # reading from a file
|
if texts_loc != "-": # reading from a file
|
||||||
texts_loc = Path(texts_loc)
|
|
||||||
if not texts_loc.exists():
|
|
||||||
msg.fail("Input text file doesn't exist", texts_loc, exits=1)
|
|
||||||
with msg.loading("Loading input texts..."):
|
with msg.loading("Loading input texts..."):
|
||||||
texts = list(srsly.read_jsonl(texts_loc))
|
texts = list(srsly.read_jsonl(texts_loc))
|
||||||
if not texts:
|
|
||||||
msg.fail("Input file is empty", texts_loc, exits=1)
|
|
||||||
msg.good("Loaded input texts")
|
|
||||||
random.shuffle(texts)
|
random.shuffle(texts)
|
||||||
else: # reading from stdin
|
else: # reading from stdin
|
||||||
msg.info("Reading input text from stdin...")
|
msg.info("Reading input text from stdin...")
|
||||||
texts = srsly.read_jsonl("-")
|
texts = srsly.read_jsonl("-")
|
||||||
|
|
||||||
with msg.loading(f"Loading model '{vectors_model}'..."):
|
|
||||||
nlp = util.load_model(vectors_model)
|
|
||||||
msg.good(f"Loaded model '{vectors_model}'")
|
|
||||||
tok2vec_path = pretrain_config["tok2vec_model"]
|
tok2vec_path = pretrain_config["tok2vec_model"]
|
||||||
tok2vec = config
|
tok2vec = config
|
||||||
for subpath in tok2vec_path.split("."):
|
for subpath in tok2vec_path.split("."):
|
||||||
tok2vec = tok2vec.get(subpath)
|
tok2vec = tok2vec.get(subpath)
|
||||||
model = create_pretraining_model(nlp, tok2vec)
|
model = create_pretraining_model(nlp, tok2vec, pretrain_config)
|
||||||
optimizer = pretrain_config["optimizer"]
|
optimizer = pretrain_config["optimizer"]
|
||||||
|
|
||||||
# Load in pretrained weights to resume from
|
# Load in pretrained weights to resume from
|
||||||
if resume_path is not None:
|
if resume_path is not None:
|
||||||
msg.info(f"Resume training tok2vec from: {resume_path}")
|
_resume_model(model, resume_path, epoch_resume)
|
||||||
with resume_path.open("rb") as file_:
|
|
||||||
weights_data = file_.read()
|
|
||||||
model.get_ref("tok2vec").from_bytes(weights_data)
|
|
||||||
# Parse the epoch number from the given weight file
|
|
||||||
model_name = re.search(r"model\d+\.bin", str(resume_path))
|
|
||||||
if model_name:
|
|
||||||
# Default weight file name so read epoch_start from it by cutting off 'model' and '.bin'
|
|
||||||
epoch_resume = int(model_name.group(0)[5:][:-4]) + 1
|
|
||||||
msg.info(f"Resuming from epoch: {epoch_resume}")
|
|
||||||
else:
|
|
||||||
if not epoch_resume:
|
|
||||||
msg.fail(
|
|
||||||
"You have to use the --epoch-resume setting when using a renamed weight file for --resume-path",
|
|
||||||
exits=True,
|
|
||||||
)
|
|
||||||
elif epoch_resume < 0:
|
|
||||||
msg.fail(
|
|
||||||
f"The argument --epoch-resume has to be greater or equal to 0. {epoch_resume} is invalid",
|
|
||||||
exits=True,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
msg.info(f"Resuming from epoch: {epoch_resume}")
|
|
||||||
else:
|
else:
|
||||||
# Without '--resume-path' the '--epoch-resume' argument is ignored
|
# Without '--resume-path' the '--epoch-resume' argument is ignored
|
||||||
epoch_resume = 0
|
epoch_resume = 0
|
||||||
|
@ -176,7 +136,7 @@ def pretrain(
|
||||||
file_.write(srsly.json_dumps(log) + "\n")
|
file_.write(srsly.json_dumps(log) + "\n")
|
||||||
|
|
||||||
skip_counter = 0
|
skip_counter = 0
|
||||||
loss_func = pretrain_config["loss_func"]
|
objective = create_objective(pretrain_config["objective"])
|
||||||
for epoch in range(epoch_resume, pretrain_config["max_epochs"]):
|
for epoch in range(epoch_resume, pretrain_config["max_epochs"]):
|
||||||
batches = util.minibatch_by_words(texts, size=pretrain_config["batch_size"])
|
batches = util.minibatch_by_words(texts, size=pretrain_config["batch_size"])
|
||||||
for batch_id, batch in enumerate(batches):
|
for batch_id, batch in enumerate(batches):
|
||||||
|
@ -187,7 +147,7 @@ def pretrain(
|
||||||
min_length=pretrain_config["min_length"],
|
min_length=pretrain_config["min_length"],
|
||||||
)
|
)
|
||||||
skip_counter += count
|
skip_counter += count
|
||||||
loss = make_update(model, docs, optimizer, distance=loss_func)
|
loss = make_update(model, docs, optimizer, objective)
|
||||||
progress = tracker.update(epoch, loss, docs)
|
progress = tracker.update(epoch, loss, docs)
|
||||||
if progress:
|
if progress:
|
||||||
msg.row(progress, **row_settings)
|
msg.row(progress, **row_settings)
|
||||||
|
@ -207,7 +167,22 @@ def pretrain(
|
||||||
msg.good("Successfully finished pretrain")
|
msg.good("Successfully finished pretrain")
|
||||||
|
|
||||||
|
|
||||||
def make_update(model, docs, optimizer, distance):
|
def _resume_model(model, resume_path, epoch_resume):
|
||||||
|
msg.info(f"Resume training tok2vec from: {resume_path}")
|
||||||
|
with resume_path.open("rb") as file_:
|
||||||
|
weights_data = file_.read()
|
||||||
|
model.get_ref("tok2vec").from_bytes(weights_data)
|
||||||
|
# Parse the epoch number from the given weight file
|
||||||
|
model_name = re.search(r"model\d+\.bin", str(resume_path))
|
||||||
|
if model_name:
|
||||||
|
# Default weight file name so read epoch_start from it by cutting off 'model' and '.bin'
|
||||||
|
epoch_resume = int(model_name.group(0)[5:][:-4]) + 1
|
||||||
|
msg.info(f"Resuming from epoch: {epoch_resume}")
|
||||||
|
else:
|
||||||
|
msg.info(f"Resuming from epoch: {epoch_resume}")
|
||||||
|
|
||||||
|
|
||||||
|
def make_update(model, docs, optimizer, objective_func):
|
||||||
"""Perform an update over a single batch of documents.
|
"""Perform an update over a single batch of documents.
|
||||||
|
|
||||||
docs (iterable): A batch of `Doc` objects.
|
docs (iterable): A batch of `Doc` objects.
|
||||||
|
@ -215,7 +190,7 @@ def make_update(model, docs, optimizer, distance):
|
||||||
RETURNS loss: A float for the loss.
|
RETURNS loss: A float for the loss.
|
||||||
"""
|
"""
|
||||||
predictions, backprop = model.begin_update(docs)
|
predictions, backprop = model.begin_update(docs)
|
||||||
loss, gradients = get_vectors_loss(model.ops, docs, predictions, distance)
|
loss, gradients = objective_func(model.ops, docs, predictions)
|
||||||
backprop(gradients)
|
backprop(gradients)
|
||||||
model.finish_update(optimizer)
|
model.finish_update(optimizer)
|
||||||
# Don't want to return a cupy object here
|
# Don't want to return a cupy object here
|
||||||
|
@ -254,13 +229,38 @@ def make_docs(nlp, batch, min_length, max_length):
|
||||||
return docs, skip_count
|
return docs, skip_count
|
||||||
|
|
||||||
|
|
||||||
def get_vectors_loss(ops, docs, prediction, distance):
|
def create_objective(config):
|
||||||
"""Compute a mean-squared error loss between the documents' vectors and
|
"""Create the objective for pretraining.
|
||||||
the prediction.
|
|
||||||
|
We'd like to replace this with a registry function but it's tricky because
|
||||||
|
we're also making a model choice based on this. For now we hard-code support
|
||||||
|
for two types (characters, vectors). For characters you can specify
|
||||||
|
n_characters, for vectors you can specify the loss.
|
||||||
|
|
||||||
|
Bleh.
|
||||||
|
"""
|
||||||
|
objective_type = config["type"]
|
||||||
|
if objective_type == "characters":
|
||||||
|
return partial(get_characters_loss, nr_char=config["n_characters"])
|
||||||
|
elif objective_type == "vectors":
|
||||||
|
if config["loss"] == "cosine":
|
||||||
|
return partial(
|
||||||
|
get_vectors_loss,
|
||||||
|
distance=CosineDistance(normalize=True, ignore_zeros=True),
|
||||||
|
)
|
||||||
|
elif config["loss"] == "L2":
|
||||||
|
return partial(
|
||||||
|
get_vectors_loss, distance=L2Distance(normalize=True, ignore_zeros=True)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError("Unexpected loss type", config["loss"])
|
||||||
|
else:
|
||||||
|
raise ValueError("Unexpected objective_type", objective_type)
|
||||||
|
|
||||||
Note that this is ripe for customization! We could compute the vectors
|
|
||||||
in some other word, e.g. with an LSTM language model, or use some other
|
def get_vectors_loss(ops, docs, prediction, distance):
|
||||||
type of objective.
|
"""Compute a loss based on a distance between the documents' vectors and
|
||||||
|
the prediction.
|
||||||
"""
|
"""
|
||||||
# The simplest way to implement this would be to vstack the
|
# The simplest way to implement this would be to vstack the
|
||||||
# token.vector values, but that's a bit inefficient, especially on GPU.
|
# token.vector values, but that's a bit inefficient, especially on GPU.
|
||||||
|
@ -272,7 +272,19 @@ def get_vectors_loss(ops, docs, prediction, distance):
|
||||||
return loss, d_target
|
return loss, d_target
|
||||||
|
|
||||||
|
|
||||||
def create_pretraining_model(nlp, tok2vec):
|
def get_characters_loss(ops, docs, prediction, nr_char):
|
||||||
|
"""Compute a loss based on a number of characters predicted from the docs."""
|
||||||
|
target_ids = numpy.vstack([doc.to_utf8_array(nr_char=nr_char) for doc in docs])
|
||||||
|
target_ids = target_ids.reshape((-1,))
|
||||||
|
target = ops.asarray(to_categorical(target_ids, n_classes=256), dtype="f")
|
||||||
|
target = target.reshape((-1, 256 * nr_char))
|
||||||
|
diff = prediction - target
|
||||||
|
loss = (diff ** 2).sum()
|
||||||
|
d_target = diff / float(prediction.shape[0])
|
||||||
|
return loss, d_target
|
||||||
|
|
||||||
|
|
||||||
|
def create_pretraining_model(nlp, tok2vec, pretrain_config):
|
||||||
"""Define a network for the pretraining. We simply add an output layer onto
|
"""Define a network for the pretraining. We simply add an output layer onto
|
||||||
the tok2vec input model. The tok2vec input model needs to be a model that
|
the tok2vec input model. The tok2vec input model needs to be a model that
|
||||||
takes a batch of Doc objects (as a list), and returns a list of arrays.
|
takes a batch of Doc objects (as a list), and returns a list of arrays.
|
||||||
|
@ -280,18 +292,24 @@ def create_pretraining_model(nlp, tok2vec):
|
||||||
The actual tok2vec layer is stored as a reference, and only this bit will be
|
The actual tok2vec layer is stored as a reference, and only this bit will be
|
||||||
serialized to file and read back in when calling the 'train' command.
|
serialized to file and read back in when calling the 'train' command.
|
||||||
"""
|
"""
|
||||||
output_size = nlp.vocab.vectors.data.shape[1]
|
# TODO
|
||||||
output_layer = chain(
|
maxout_pieces = 3
|
||||||
Maxout(nO=300, nP=3, normalize=True, dropout=0.0), Linear(output_size)
|
hidden_size = 300
|
||||||
)
|
if pretrain_config["objective"]["type"] == "vectors":
|
||||||
model = chain(tok2vec, list2array())
|
model = build_cloze_multi_task_model(
|
||||||
model = chain(model, output_layer)
|
nlp.vocab, tok2vec, hidden_size=hidden_size, maxout_pieces=maxout_pieces
|
||||||
|
)
|
||||||
|
elif pretrain_config["objective"]["type"] == "characters":
|
||||||
|
model = build_cloze_characters_multi_task_model(
|
||||||
|
nlp.vocab,
|
||||||
|
tok2vec,
|
||||||
|
hidden_size=hidden_size,
|
||||||
|
maxout_pieces=maxout_pieces,
|
||||||
|
nr_char=pretrain_config["objective"]["n_characters"],
|
||||||
|
)
|
||||||
model.initialize(X=[nlp.make_doc("Give it a doc to infer shapes")])
|
model.initialize(X=[nlp.make_doc("Give it a doc to infer shapes")])
|
||||||
mlm_model = build_masked_language_model(nlp.vocab, model)
|
set_dropout_rate(model, pretrain_config["dropout"])
|
||||||
mlm_model.set_ref("tok2vec", tok2vec)
|
return model
|
||||||
mlm_model.set_ref("output_layer", output_layer)
|
|
||||||
mlm_model.initialize(X=[nlp.make_doc("Give it a doc to infer shapes")])
|
|
||||||
return mlm_model
|
|
||||||
|
|
||||||
|
|
||||||
class ProgressTracker(object):
|
class ProgressTracker(object):
|
||||||
|
@ -340,3 +358,53 @@ def _smart_round(figure, width=10, max_decimal=4):
|
||||||
n_decimal = min(n_decimal, max_decimal)
|
n_decimal = min(n_decimal, max_decimal)
|
||||||
format_str = "%." + str(n_decimal) + "f"
|
format_str = "%." + str(n_decimal) + "f"
|
||||||
return format_str % figure
|
return format_str % figure
|
||||||
|
|
||||||
|
|
||||||
|
def verify_cli_args(
|
||||||
|
texts_loc, output_dir, config_path, use_gpu, resume_path, epoch_resume
|
||||||
|
):
|
||||||
|
if not config_path or not config_path.exists():
|
||||||
|
msg.fail("Config file not found", config_path, exits=1)
|
||||||
|
if output_dir.exists() and [p for p in output_dir.iterdir()]:
|
||||||
|
if resume_path:
|
||||||
|
msg.warn(
|
||||||
|
"Output directory is not empty. ",
|
||||||
|
"If you're resuming a run from a previous model in this directory, "
|
||||||
|
"the old models for the consecutive epochs will be overwritten "
|
||||||
|
"with the new ones.",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
msg.warn(
|
||||||
|
"Output directory is not empty. ",
|
||||||
|
"It is better to use an empty directory or refer to a new output path, "
|
||||||
|
"then the new directory will be created for you.",
|
||||||
|
)
|
||||||
|
if texts_loc != "-": # reading from a file
|
||||||
|
texts_loc = Path(texts_loc)
|
||||||
|
if not texts_loc.exists():
|
||||||
|
msg.fail("Input text file doesn't exist", texts_loc, exits=1)
|
||||||
|
|
||||||
|
for text in srsly.read_jsonl(texts_loc):
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
msg.fail("Input file is empty", texts_loc, exits=1)
|
||||||
|
|
||||||
|
if resume_path is not None:
|
||||||
|
model_name = re.search(r"model\d+\.bin", str(resume_path))
|
||||||
|
if not model_name and not epoch_resume:
|
||||||
|
msg.fail(
|
||||||
|
"You have to use the --epoch-resume setting when using a renamed weight file for --resume-path",
|
||||||
|
exits=True,
|
||||||
|
)
|
||||||
|
elif not model_name and epoch_resume < 0:
|
||||||
|
msg.fail(
|
||||||
|
f"The argument --epoch-resume has to be greater or equal to 0. {epoch_resume} is invalid",
|
||||||
|
exits=True,
|
||||||
|
)
|
||||||
|
config = util.load_config(config_path, create_objects=False)
|
||||||
|
if config["pretraining"]["objective"]["type"] == "vectors":
|
||||||
|
if not config["nlp"]["vectors"]:
|
||||||
|
msg.fail(
|
||||||
|
"Must specify nlp.vectors if pretraining.objective.type is vectors",
|
||||||
|
exits=True
|
||||||
|
)
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
import numpy
|
import numpy
|
||||||
|
|
||||||
from thinc.api import chain, Maxout, LayerNorm, Softmax, Linear, zero_init, Model
|
from thinc.api import chain, Maxout, LayerNorm, Softmax, Linear, zero_init, Model
|
||||||
|
from thinc.api import MultiSoftmax, list2array
|
||||||
|
|
||||||
|
|
||||||
def build_multi_task_model(tok2vec, maxout_pieces, token_vector_width, nO=None):
|
def build_multi_task_model(tok2vec, maxout_pieces, token_vector_width, nO=None):
|
||||||
|
@ -21,9 +22,10 @@ def build_multi_task_model(tok2vec, maxout_pieces, token_vector_width, nO=None):
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
def build_cloze_multi_task_model(vocab, tok2vec, maxout_pieces, nO=None):
|
def build_cloze_multi_task_model(vocab, tok2vec, maxout_pieces, hidden_size, nO=None):
|
||||||
# nO = vocab.vectors.data.shape[1]
|
# nO = vocab.vectors.data.shape[1]
|
||||||
output_layer = chain(
|
output_layer = chain(
|
||||||
|
list2array(),
|
||||||
Maxout(
|
Maxout(
|
||||||
nO=nO,
|
nO=nO,
|
||||||
nI=tok2vec.get_dim("nO"),
|
nI=tok2vec.get_dim("nO"),
|
||||||
|
@ -40,6 +42,22 @@ def build_cloze_multi_task_model(vocab, tok2vec, maxout_pieces, nO=None):
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def build_cloze_characters_multi_task_model(
|
||||||
|
vocab, tok2vec, maxout_pieces, hidden_size, nr_char
|
||||||
|
):
|
||||||
|
output_layer = chain(
|
||||||
|
list2array(),
|
||||||
|
Maxout(hidden_size, nP=maxout_pieces),
|
||||||
|
LayerNorm(nI=hidden_size),
|
||||||
|
MultiSoftmax([256] * nr_char, nI=hidden_size),
|
||||||
|
)
|
||||||
|
|
||||||
|
model = build_masked_language_model(vocab, chain(tok2vec, output_layer))
|
||||||
|
model.set_ref("tok2vec", tok2vec)
|
||||||
|
model.set_ref("output_layer", output_layer)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
def build_masked_language_model(vocab, wrapped_model, mask_prob=0.15):
|
def build_masked_language_model(vocab, wrapped_model, mask_prob=0.15):
|
||||||
"""Convert a model into a BERT-style masked language model"""
|
"""Convert a model into a BERT-style masked language model"""
|
||||||
|
|
||||||
|
@ -48,7 +66,7 @@ def build_masked_language_model(vocab, wrapped_model, mask_prob=0.15):
|
||||||
def mlm_forward(model, docs, is_train):
|
def mlm_forward(model, docs, is_train):
|
||||||
mask, docs = _apply_mask(docs, random_words, mask_prob=mask_prob)
|
mask, docs = _apply_mask(docs, random_words, mask_prob=mask_prob)
|
||||||
mask = model.ops.asarray(mask).reshape((mask.shape[0], 1))
|
mask = model.ops.asarray(mask).reshape((mask.shape[0], 1))
|
||||||
output, backprop = model.get_ref("wrapped-model").begin_update(docs)
|
output, backprop = model.layers[0](docs, is_train)
|
||||||
|
|
||||||
def mlm_backward(d_output):
|
def mlm_backward(d_output):
|
||||||
d_output *= 1 - mask
|
d_output *= 1 - mask
|
||||||
|
@ -56,8 +74,22 @@ def build_masked_language_model(vocab, wrapped_model, mask_prob=0.15):
|
||||||
|
|
||||||
return output, mlm_backward
|
return output, mlm_backward
|
||||||
|
|
||||||
mlm_model = Model("masked-language-model", mlm_forward, layers=[wrapped_model])
|
def mlm_initialize(model, X=None, Y=None):
|
||||||
mlm_model.set_ref("wrapped-model", wrapped_model)
|
wrapped = model.layers[0]
|
||||||
|
wrapped.initialize(X=X, Y=Y)
|
||||||
|
for dim in wrapped.dim_names:
|
||||||
|
if wrapped.has_dim(dim):
|
||||||
|
model.set_dim(dim, wrapped.get_dim(dim))
|
||||||
|
|
||||||
|
mlm_model = Model(
|
||||||
|
"masked-language-model",
|
||||||
|
mlm_forward,
|
||||||
|
layers=[wrapped_model],
|
||||||
|
init=mlm_initialize,
|
||||||
|
refs={"wrapped": wrapped_model},
|
||||||
|
dims={dim: None for dim in wrapped_model.dim_names},
|
||||||
|
)
|
||||||
|
mlm_model.set_ref("wrapped", wrapped_model)
|
||||||
|
|
||||||
return mlm_model
|
return mlm_model
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user