mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-11-04 01:48:04 +03:00 
			
		
		
		
	Refactor pretrain and support character-based objective for v3 (#5706)
* Start adding character-based stuff * Start adding character-based objective * Start adding character-based stuff * Start adding character-based objective * Remove outdated comment * Update pretraining models * Add/fix character-based multi-task models * Refactor pretrain and support character-based objective * Update pretrain config * Remove unused * Fix flake8 errors * Clean up imports * Format * Format * Update Thinc version * Raise error if vectors objective but no vectors
This commit is contained in:
		
							parent
							
								
									cdf9ee1716
								
							
						
					
					
						commit
						2bd1bf81f1
					
				| 
						 | 
				
			
			@ -54,6 +54,10 @@ seed = ${training:seed}
 | 
			
		|||
use_pytorch_for_gpu_memory = ${training:use_pytorch_for_gpu_memory}
 | 
			
		||||
tok2vec_model = "nlp.pipeline.tok2vec.model"
 | 
			
		||||
 | 
			
		||||
[pretraining.objective]
 | 
			
		||||
type = "characters"
 | 
			
		||||
n_characters = 4
 | 
			
		||||
 | 
			
		||||
[pretraining.optimizer]
 | 
			
		||||
@optimizers = "Adam.v1"
 | 
			
		||||
beta1 = 0.9
 | 
			
		||||
| 
						 | 
				
			
			@ -65,10 +69,6 @@ use_averages = true
 | 
			
		|||
eps = 1e-8
 | 
			
		||||
learn_rate = 0.001
 | 
			
		||||
 | 
			
		||||
[pretraining.loss_func]
 | 
			
		||||
@losses = "CosineDistance.v1"
 | 
			
		||||
normalize = true
 | 
			
		||||
 | 
			
		||||
[nlp]
 | 
			
		||||
lang = "en"
 | 
			
		||||
vectors = null
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -6,7 +6,7 @@ requires = [
 | 
			
		|||
    "cymem>=2.0.2,<2.1.0",
 | 
			
		||||
    "preshed>=3.0.2,<3.1.0",
 | 
			
		||||
    "murmurhash>=0.28.0,<1.1.0",
 | 
			
		||||
    "thinc==8.0.0a11",
 | 
			
		||||
    "thinc>=8.0.0a12,<8.0.0a20",
 | 
			
		||||
    "blis>=0.4.0,<0.5.0"
 | 
			
		||||
]
 | 
			
		||||
build-backend = "setuptools.build_meta"
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,7 +1,7 @@
 | 
			
		|||
# Our libraries
 | 
			
		||||
cymem>=2.0.2,<2.1.0
 | 
			
		||||
preshed>=3.0.2,<3.1.0
 | 
			
		||||
thinc==8.0.0a11
 | 
			
		||||
thinc>=8.0.0a12,<8.0.0a20
 | 
			
		||||
blis>=0.4.0,<0.5.0
 | 
			
		||||
ml_datasets>=0.1.1
 | 
			
		||||
murmurhash>=0.28.0,<1.1.0
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -34,13 +34,13 @@ setup_requires =
 | 
			
		|||
    cymem>=2.0.2,<2.1.0
 | 
			
		||||
    preshed>=3.0.2,<3.1.0
 | 
			
		||||
    murmurhash>=0.28.0,<1.1.0
 | 
			
		||||
    thinc==8.0.0a11
 | 
			
		||||
    thinc>=8.0.0a12,<8.0.0a20
 | 
			
		||||
install_requires =
 | 
			
		||||
    # Our libraries
 | 
			
		||||
    murmurhash>=0.28.0,<1.1.0
 | 
			
		||||
    cymem>=2.0.2,<2.1.0
 | 
			
		||||
    preshed>=3.0.2,<3.1.0
 | 
			
		||||
    thinc==8.0.0a11
 | 
			
		||||
    thinc>=8.0.0a11,<8.0.0a20
 | 
			
		||||
    blis>=0.4.0,<0.5.0
 | 
			
		||||
    wasabi>=0.7.0,<1.1.0
 | 
			
		||||
    srsly>=2.1.0,<3.0.0
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -5,13 +5,17 @@ import time
 | 
			
		|||
import re
 | 
			
		||||
from collections import Counter
 | 
			
		||||
from pathlib import Path
 | 
			
		||||
from thinc.api import Linear, Maxout, chain, list2array, use_pytorch_for_gpu_memory
 | 
			
		||||
from thinc.api import use_pytorch_for_gpu_memory
 | 
			
		||||
from thinc.api import set_dropout_rate, to_categorical
 | 
			
		||||
from thinc.api import CosineDistance, L2Distance
 | 
			
		||||
from wasabi import msg
 | 
			
		||||
import srsly
 | 
			
		||||
from functools import partial
 | 
			
		||||
 | 
			
		||||
from ._app import app, Arg, Opt
 | 
			
		||||
from ..errors import Errors
 | 
			
		||||
from ..ml.models.multi_task import build_masked_language_model
 | 
			
		||||
from ..ml.models.multi_task import build_cloze_multi_task_model
 | 
			
		||||
from ..ml.models.multi_task import build_cloze_characters_multi_task_model
 | 
			
		||||
from ..tokens import Doc
 | 
			
		||||
from ..attrs import ID, HEAD
 | 
			
		||||
from .. import util
 | 
			
		||||
| 
						 | 
				
			
			@ -21,7 +25,6 @@ from .. import util
 | 
			
		|||
def pretrain_cli(
 | 
			
		||||
    # fmt: off
 | 
			
		||||
    texts_loc: Path = Arg(..., help="Path to JSONL file with raw texts to learn from, with text provided as the key 'text' or tokens as the key 'tokens'", exists=True),
 | 
			
		||||
    vectors_model: str = Arg(..., help="Name or path to spaCy model with vectors to learn from"),
 | 
			
		||||
    output_dir: Path = Arg(..., help="Directory to write models to on each epoch"),
 | 
			
		||||
    config_path: Path = Arg(..., help="Path to config file", exists=True, dir_okay=False),
 | 
			
		||||
    use_gpu: int = Opt(-1, "--use-gpu", "-g", help="Use GPU"),
 | 
			
		||||
| 
						 | 
				
			
			@ -31,11 +34,15 @@ def pretrain_cli(
 | 
			
		|||
):
 | 
			
		||||
    """
 | 
			
		||||
    Pre-train the 'token-to-vector' (tok2vec) layer of pipeline components,
 | 
			
		||||
    using an approximate language-modelling objective. Specifically, we load
 | 
			
		||||
    pretrained vectors, and train a component like a CNN, BiLSTM, etc to predict
 | 
			
		||||
    vectors which match the pretrained ones. The weights are saved to a directory
 | 
			
		||||
    after each epoch. You can then pass a path to one of these pretrained weights
 | 
			
		||||
    files to the 'spacy train' command.
 | 
			
		||||
    using an approximate language-modelling objective. Two objective types
 | 
			
		||||
    are available, vector-based and character-based.
 | 
			
		||||
    
 | 
			
		||||
    In the vector-based objective, we load word vectors that have been trained
 | 
			
		||||
    using a word2vec-style distributional similarity algorithm, and train a
 | 
			
		||||
    component like a CNN, BiLSTM, etc to predict vectors which match the
 | 
			
		||||
    pretrained ones. The weights are saved to a directory after each epoch. You
 | 
			
		||||
    can then pass a path to one of these pretrained weights files to the
 | 
			
		||||
    'spacy train' command.
 | 
			
		||||
 | 
			
		||||
    This technique may be especially helpful if you have little labelled data.
 | 
			
		||||
    However, it's still quite experimental, so your mileage may vary.
 | 
			
		||||
| 
						 | 
				
			
			@ -46,7 +53,6 @@ def pretrain_cli(
 | 
			
		|||
    """
 | 
			
		||||
    pretrain(
 | 
			
		||||
        texts_loc,
 | 
			
		||||
        vectors_model,
 | 
			
		||||
        output_dir,
 | 
			
		||||
        config_path,
 | 
			
		||||
        use_gpu=use_gpu,
 | 
			
		||||
| 
						 | 
				
			
			@ -57,15 +63,16 @@ def pretrain_cli(
 | 
			
		|||
 | 
			
		||||
def pretrain(
 | 
			
		||||
    texts_loc: Path,
 | 
			
		||||
    vectors_model: str,
 | 
			
		||||
    output_dir: Path,
 | 
			
		||||
    config_path: Path,
 | 
			
		||||
    use_gpu: int = -1,
 | 
			
		||||
    resume_path: Optional[Path] = None,
 | 
			
		||||
    epoch_resume: Optional[int] = None,
 | 
			
		||||
):
 | 
			
		||||
    if not config_path or not config_path.exists():
 | 
			
		||||
        msg.fail("Config file not found", config_path, exits=1)
 | 
			
		||||
    verify_cli_args(**locals())
 | 
			
		||||
    if not output_dir.exists():
 | 
			
		||||
        output_dir.mkdir()
 | 
			
		||||
        msg.good(f"Created output directory: {output_dir}")
 | 
			
		||||
 | 
			
		||||
    if use_gpu >= 0:
 | 
			
		||||
        msg.info("Using GPU")
 | 
			
		||||
| 
						 | 
				
			
			@ -76,82 +83,35 @@ def pretrain(
 | 
			
		|||
    msg.info(f"Loading config from: {config_path}")
 | 
			
		||||
    config = util.load_config(config_path, create_objects=False)
 | 
			
		||||
    util.fix_random_seed(config["pretraining"]["seed"])
 | 
			
		||||
    if config["pretraining"]["use_pytorch_for_gpu_memory"]:
 | 
			
		||||
    if use_gpu >= 0 and config["pretraining"]["use_pytorch_for_gpu_memory"]:
 | 
			
		||||
        use_pytorch_for_gpu_memory()
 | 
			
		||||
 | 
			
		||||
    if output_dir.exists() and [p for p in output_dir.iterdir()]:
 | 
			
		||||
        if resume_path:
 | 
			
		||||
            msg.warn(
 | 
			
		||||
                "Output directory is not empty. ",
 | 
			
		||||
                "If you're resuming a run from a previous model in this directory, "
 | 
			
		||||
                "the old models for the consecutive epochs will be overwritten "
 | 
			
		||||
                "with the new ones.",
 | 
			
		||||
            )
 | 
			
		||||
        else:
 | 
			
		||||
            msg.warn(
 | 
			
		||||
                "Output directory is not empty. ",
 | 
			
		||||
                "It is better to use an empty directory or refer to a new output path, "
 | 
			
		||||
                "then the new directory will be created for you.",
 | 
			
		||||
            )
 | 
			
		||||
    if not output_dir.exists():
 | 
			
		||||
        output_dir.mkdir()
 | 
			
		||||
        msg.good(f"Created output directory: {output_dir}")
 | 
			
		||||
    nlp_config = config["nlp"]
 | 
			
		||||
    srsly.write_json(output_dir / "config.json", config)
 | 
			
		||||
    msg.good("Saved config file in the output directory")
 | 
			
		||||
 | 
			
		||||
    config = util.load_config(config_path, create_objects=True)
 | 
			
		||||
    nlp = util.load_model_from_config(nlp_config)
 | 
			
		||||
    pretrain_config = config["pretraining"]
 | 
			
		||||
 | 
			
		||||
    # Load texts from file or stdin
 | 
			
		||||
    if texts_loc != "-":  # reading from a file
 | 
			
		||||
        texts_loc = Path(texts_loc)
 | 
			
		||||
        if not texts_loc.exists():
 | 
			
		||||
            msg.fail("Input text file doesn't exist", texts_loc, exits=1)
 | 
			
		||||
        with msg.loading("Loading input texts..."):
 | 
			
		||||
            texts = list(srsly.read_jsonl(texts_loc))
 | 
			
		||||
        if not texts:
 | 
			
		||||
            msg.fail("Input file is empty", texts_loc, exits=1)
 | 
			
		||||
        msg.good("Loaded input texts")
 | 
			
		||||
        random.shuffle(texts)
 | 
			
		||||
    else:  # reading from stdin
 | 
			
		||||
        msg.info("Reading input text from stdin...")
 | 
			
		||||
        texts = srsly.read_jsonl("-")
 | 
			
		||||
 | 
			
		||||
    with msg.loading(f"Loading model '{vectors_model}'..."):
 | 
			
		||||
        nlp = util.load_model(vectors_model)
 | 
			
		||||
    msg.good(f"Loaded model '{vectors_model}'")
 | 
			
		||||
    tok2vec_path = pretrain_config["tok2vec_model"]
 | 
			
		||||
    tok2vec = config
 | 
			
		||||
    for subpath in tok2vec_path.split("."):
 | 
			
		||||
        tok2vec = tok2vec.get(subpath)
 | 
			
		||||
    model = create_pretraining_model(nlp, tok2vec)
 | 
			
		||||
    model = create_pretraining_model(nlp, tok2vec, pretrain_config)
 | 
			
		||||
    optimizer = pretrain_config["optimizer"]
 | 
			
		||||
 | 
			
		||||
    # Load in pretrained weights to resume from
 | 
			
		||||
    if resume_path is not None:
 | 
			
		||||
        msg.info(f"Resume training tok2vec from: {resume_path}")
 | 
			
		||||
        with resume_path.open("rb") as file_:
 | 
			
		||||
            weights_data = file_.read()
 | 
			
		||||
            model.get_ref("tok2vec").from_bytes(weights_data)
 | 
			
		||||
        # Parse the epoch number from the given weight file
 | 
			
		||||
        model_name = re.search(r"model\d+\.bin", str(resume_path))
 | 
			
		||||
        if model_name:
 | 
			
		||||
            # Default weight file name so read epoch_start from it by cutting off 'model' and '.bin'
 | 
			
		||||
            epoch_resume = int(model_name.group(0)[5:][:-4]) + 1
 | 
			
		||||
            msg.info(f"Resuming from epoch: {epoch_resume}")
 | 
			
		||||
        else:
 | 
			
		||||
            if not epoch_resume:
 | 
			
		||||
                msg.fail(
 | 
			
		||||
                    "You have to use the --epoch-resume setting when using a renamed weight file for --resume-path",
 | 
			
		||||
                    exits=True,
 | 
			
		||||
                )
 | 
			
		||||
            elif epoch_resume < 0:
 | 
			
		||||
                msg.fail(
 | 
			
		||||
                    f"The argument --epoch-resume has to be greater or equal to 0. {epoch_resume} is invalid",
 | 
			
		||||
                    exits=True,
 | 
			
		||||
                )
 | 
			
		||||
            else:
 | 
			
		||||
                msg.info(f"Resuming from epoch: {epoch_resume}")
 | 
			
		||||
        _resume_model(model, resume_path, epoch_resume)
 | 
			
		||||
    else:
 | 
			
		||||
        # Without '--resume-path' the '--epoch-resume' argument is ignored
 | 
			
		||||
        epoch_resume = 0
 | 
			
		||||
| 
						 | 
				
			
			@ -176,7 +136,7 @@ def pretrain(
 | 
			
		|||
                file_.write(srsly.json_dumps(log) + "\n")
 | 
			
		||||
 | 
			
		||||
    skip_counter = 0
 | 
			
		||||
    loss_func = pretrain_config["loss_func"]
 | 
			
		||||
    objective = create_objective(pretrain_config["objective"])
 | 
			
		||||
    for epoch in range(epoch_resume, pretrain_config["max_epochs"]):
 | 
			
		||||
        batches = util.minibatch_by_words(texts, size=pretrain_config["batch_size"])
 | 
			
		||||
        for batch_id, batch in enumerate(batches):
 | 
			
		||||
| 
						 | 
				
			
			@ -187,7 +147,7 @@ def pretrain(
 | 
			
		|||
                min_length=pretrain_config["min_length"],
 | 
			
		||||
            )
 | 
			
		||||
            skip_counter += count
 | 
			
		||||
            loss = make_update(model, docs, optimizer, distance=loss_func)
 | 
			
		||||
            loss = make_update(model, docs, optimizer, objective)
 | 
			
		||||
            progress = tracker.update(epoch, loss, docs)
 | 
			
		||||
            if progress:
 | 
			
		||||
                msg.row(progress, **row_settings)
 | 
			
		||||
| 
						 | 
				
			
			@ -207,7 +167,22 @@ def pretrain(
 | 
			
		|||
    msg.good("Successfully finished pretrain")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def make_update(model, docs, optimizer, distance):
 | 
			
		||||
def _resume_model(model, resume_path, epoch_resume):
 | 
			
		||||
    msg.info(f"Resume training tok2vec from: {resume_path}")
 | 
			
		||||
    with resume_path.open("rb") as file_:
 | 
			
		||||
        weights_data = file_.read()
 | 
			
		||||
        model.get_ref("tok2vec").from_bytes(weights_data)
 | 
			
		||||
    # Parse the epoch number from the given weight file
 | 
			
		||||
    model_name = re.search(r"model\d+\.bin", str(resume_path))
 | 
			
		||||
    if model_name:
 | 
			
		||||
        # Default weight file name so read epoch_start from it by cutting off 'model' and '.bin'
 | 
			
		||||
        epoch_resume = int(model_name.group(0)[5:][:-4]) + 1
 | 
			
		||||
        msg.info(f"Resuming from epoch: {epoch_resume}")
 | 
			
		||||
    else:
 | 
			
		||||
        msg.info(f"Resuming from epoch: {epoch_resume}")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def make_update(model, docs, optimizer, objective_func):
 | 
			
		||||
    """Perform an update over a single batch of documents.
 | 
			
		||||
 | 
			
		||||
    docs (iterable): A batch of `Doc` objects.
 | 
			
		||||
| 
						 | 
				
			
			@ -215,7 +190,7 @@ def make_update(model, docs, optimizer, distance):
 | 
			
		|||
    RETURNS loss: A float for the loss.
 | 
			
		||||
    """
 | 
			
		||||
    predictions, backprop = model.begin_update(docs)
 | 
			
		||||
    loss, gradients = get_vectors_loss(model.ops, docs, predictions, distance)
 | 
			
		||||
    loss, gradients = objective_func(model.ops, docs, predictions)
 | 
			
		||||
    backprop(gradients)
 | 
			
		||||
    model.finish_update(optimizer)
 | 
			
		||||
    # Don't want to return a cupy object here
 | 
			
		||||
| 
						 | 
				
			
			@ -254,13 +229,38 @@ def make_docs(nlp, batch, min_length, max_length):
 | 
			
		|||
    return docs, skip_count
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_vectors_loss(ops, docs, prediction, distance):
 | 
			
		||||
    """Compute a mean-squared error loss between the documents' vectors and
 | 
			
		||||
    the prediction.
 | 
			
		||||
def create_objective(config):
 | 
			
		||||
    """Create the objective for pretraining.
 | 
			
		||||
    
 | 
			
		||||
    We'd like to replace this with a registry function but it's tricky because
 | 
			
		||||
    we're also making a model choice based on this. For now we hard-code support
 | 
			
		||||
    for two types (characters, vectors). For characters you can specify
 | 
			
		||||
    n_characters, for vectors you can specify the loss.
 | 
			
		||||
    
 | 
			
		||||
    Bleh.
 | 
			
		||||
    """
 | 
			
		||||
    objective_type = config["type"]
 | 
			
		||||
    if objective_type == "characters":
 | 
			
		||||
        return partial(get_characters_loss, nr_char=config["n_characters"])
 | 
			
		||||
    elif objective_type == "vectors":
 | 
			
		||||
        if config["loss"] == "cosine":
 | 
			
		||||
            return partial(
 | 
			
		||||
                get_vectors_loss,
 | 
			
		||||
                distance=CosineDistance(normalize=True, ignore_zeros=True),
 | 
			
		||||
            )
 | 
			
		||||
        elif config["loss"] == "L2":
 | 
			
		||||
            return partial(
 | 
			
		||||
                get_vectors_loss, distance=L2Distance(normalize=True, ignore_zeros=True)
 | 
			
		||||
            )
 | 
			
		||||
        else:
 | 
			
		||||
            raise ValueError("Unexpected loss type", config["loss"])
 | 
			
		||||
    else:
 | 
			
		||||
        raise ValueError("Unexpected objective_type", objective_type)
 | 
			
		||||
 | 
			
		||||
    Note that this is ripe for customization! We could compute the vectors
 | 
			
		||||
    in some other word, e.g. with an LSTM language model, or use some other
 | 
			
		||||
    type of objective.
 | 
			
		||||
 | 
			
		||||
def get_vectors_loss(ops, docs, prediction, distance):
 | 
			
		||||
    """Compute a loss based on a distance between the documents' vectors and
 | 
			
		||||
    the prediction.
 | 
			
		||||
    """
 | 
			
		||||
    # The simplest way to implement this would be to vstack the
 | 
			
		||||
    # token.vector values, but that's a bit inefficient, especially on GPU.
 | 
			
		||||
| 
						 | 
				
			
			@ -272,7 +272,19 @@ def get_vectors_loss(ops, docs, prediction, distance):
 | 
			
		|||
    return loss, d_target
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def create_pretraining_model(nlp, tok2vec):
 | 
			
		||||
def get_characters_loss(ops, docs, prediction, nr_char):
 | 
			
		||||
    """Compute a loss based on a number of characters predicted from the docs."""
 | 
			
		||||
    target_ids = numpy.vstack([doc.to_utf8_array(nr_char=nr_char) for doc in docs])
 | 
			
		||||
    target_ids = target_ids.reshape((-1,))
 | 
			
		||||
    target = ops.asarray(to_categorical(target_ids, n_classes=256), dtype="f")
 | 
			
		||||
    target = target.reshape((-1, 256 * nr_char))
 | 
			
		||||
    diff = prediction - target
 | 
			
		||||
    loss = (diff ** 2).sum()
 | 
			
		||||
    d_target = diff / float(prediction.shape[0])
 | 
			
		||||
    return loss, d_target
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def create_pretraining_model(nlp, tok2vec, pretrain_config):
 | 
			
		||||
    """Define a network for the pretraining. We simply add an output layer onto
 | 
			
		||||
    the tok2vec input model. The tok2vec input model needs to be a model that
 | 
			
		||||
    takes a batch of Doc objects (as a list), and returns a list of arrays.
 | 
			
		||||
| 
						 | 
				
			
			@ -280,18 +292,24 @@ def create_pretraining_model(nlp, tok2vec):
 | 
			
		|||
    The actual tok2vec layer is stored as a reference, and only this bit will be
 | 
			
		||||
    serialized to file and read back in when calling the 'train' command.
 | 
			
		||||
    """
 | 
			
		||||
    output_size = nlp.vocab.vectors.data.shape[1]
 | 
			
		||||
    output_layer = chain(
 | 
			
		||||
        Maxout(nO=300, nP=3, normalize=True, dropout=0.0), Linear(output_size)
 | 
			
		||||
    )
 | 
			
		||||
    model = chain(tok2vec, list2array())
 | 
			
		||||
    model = chain(model, output_layer)
 | 
			
		||||
    # TODO
 | 
			
		||||
    maxout_pieces = 3
 | 
			
		||||
    hidden_size = 300
 | 
			
		||||
    if pretrain_config["objective"]["type"] == "vectors":
 | 
			
		||||
        model = build_cloze_multi_task_model(
 | 
			
		||||
            nlp.vocab, tok2vec, hidden_size=hidden_size, maxout_pieces=maxout_pieces
 | 
			
		||||
        )
 | 
			
		||||
    elif pretrain_config["objective"]["type"] == "characters":
 | 
			
		||||
        model = build_cloze_characters_multi_task_model(
 | 
			
		||||
            nlp.vocab,
 | 
			
		||||
            tok2vec,
 | 
			
		||||
            hidden_size=hidden_size,
 | 
			
		||||
            maxout_pieces=maxout_pieces,
 | 
			
		||||
            nr_char=pretrain_config["objective"]["n_characters"],
 | 
			
		||||
        )
 | 
			
		||||
    model.initialize(X=[nlp.make_doc("Give it a doc to infer shapes")])
 | 
			
		||||
    mlm_model = build_masked_language_model(nlp.vocab, model)
 | 
			
		||||
    mlm_model.set_ref("tok2vec", tok2vec)
 | 
			
		||||
    mlm_model.set_ref("output_layer", output_layer)
 | 
			
		||||
    mlm_model.initialize(X=[nlp.make_doc("Give it a doc to infer shapes")])
 | 
			
		||||
    return mlm_model
 | 
			
		||||
    set_dropout_rate(model, pretrain_config["dropout"])
 | 
			
		||||
    return model
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ProgressTracker(object):
 | 
			
		||||
| 
						 | 
				
			
			@ -340,3 +358,53 @@ def _smart_round(figure, width=10, max_decimal=4):
 | 
			
		|||
        n_decimal = min(n_decimal, max_decimal)
 | 
			
		||||
        format_str = "%." + str(n_decimal) + "f"
 | 
			
		||||
        return format_str % figure
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def verify_cli_args(
 | 
			
		||||
    texts_loc, output_dir, config_path, use_gpu, resume_path, epoch_resume
 | 
			
		||||
):
 | 
			
		||||
    if not config_path or not config_path.exists():
 | 
			
		||||
        msg.fail("Config file not found", config_path, exits=1)
 | 
			
		||||
    if output_dir.exists() and [p for p in output_dir.iterdir()]:
 | 
			
		||||
        if resume_path:
 | 
			
		||||
            msg.warn(
 | 
			
		||||
                "Output directory is not empty. ",
 | 
			
		||||
                "If you're resuming a run from a previous model in this directory, "
 | 
			
		||||
                "the old models for the consecutive epochs will be overwritten "
 | 
			
		||||
                "with the new ones.",
 | 
			
		||||
            )
 | 
			
		||||
        else:
 | 
			
		||||
            msg.warn(
 | 
			
		||||
                "Output directory is not empty. ",
 | 
			
		||||
                "It is better to use an empty directory or refer to a new output path, "
 | 
			
		||||
                "then the new directory will be created for you.",
 | 
			
		||||
            )
 | 
			
		||||
    if texts_loc != "-":  # reading from a file
 | 
			
		||||
        texts_loc = Path(texts_loc)
 | 
			
		||||
        if not texts_loc.exists():
 | 
			
		||||
            msg.fail("Input text file doesn't exist", texts_loc, exits=1)
 | 
			
		||||
 | 
			
		||||
        for text in srsly.read_jsonl(texts_loc):
 | 
			
		||||
            break
 | 
			
		||||
        else:
 | 
			
		||||
            msg.fail("Input file is empty", texts_loc, exits=1)
 | 
			
		||||
 | 
			
		||||
    if resume_path is not None:
 | 
			
		||||
        model_name = re.search(r"model\d+\.bin", str(resume_path))
 | 
			
		||||
        if not model_name and not epoch_resume:
 | 
			
		||||
            msg.fail(
 | 
			
		||||
                "You have to use the --epoch-resume setting when using a renamed weight file for --resume-path",
 | 
			
		||||
                exits=True,
 | 
			
		||||
            )
 | 
			
		||||
        elif not model_name and epoch_resume < 0:
 | 
			
		||||
            msg.fail(
 | 
			
		||||
                f"The argument --epoch-resume has to be greater or equal to 0. {epoch_resume} is invalid",
 | 
			
		||||
                exits=True,
 | 
			
		||||
            )
 | 
			
		||||
    config = util.load_config(config_path, create_objects=False)
 | 
			
		||||
    if config["pretraining"]["objective"]["type"] == "vectors":
 | 
			
		||||
        if not config["nlp"]["vectors"]:
 | 
			
		||||
            msg.fail(
 | 
			
		||||
                "Must specify nlp.vectors if pretraining.objective.type is vectors",
 | 
			
		||||
                exits=True
 | 
			
		||||
            )
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,6 +1,7 @@
 | 
			
		|||
import numpy
 | 
			
		||||
 | 
			
		||||
from thinc.api import chain, Maxout, LayerNorm, Softmax, Linear, zero_init, Model
 | 
			
		||||
from thinc.api import MultiSoftmax, list2array
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def build_multi_task_model(tok2vec, maxout_pieces, token_vector_width, nO=None):
 | 
			
		||||
| 
						 | 
				
			
			@ -21,9 +22,10 @@ def build_multi_task_model(tok2vec, maxout_pieces, token_vector_width, nO=None):
 | 
			
		|||
    return model
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def build_cloze_multi_task_model(vocab, tok2vec, maxout_pieces, nO=None):
 | 
			
		||||
def build_cloze_multi_task_model(vocab, tok2vec, maxout_pieces, hidden_size, nO=None):
 | 
			
		||||
    # nO = vocab.vectors.data.shape[1]
 | 
			
		||||
    output_layer = chain(
 | 
			
		||||
        list2array(),
 | 
			
		||||
        Maxout(
 | 
			
		||||
            nO=nO,
 | 
			
		||||
            nI=tok2vec.get_dim("nO"),
 | 
			
		||||
| 
						 | 
				
			
			@ -40,6 +42,22 @@ def build_cloze_multi_task_model(vocab, tok2vec, maxout_pieces, nO=None):
 | 
			
		|||
    return model
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def build_cloze_characters_multi_task_model(
 | 
			
		||||
    vocab, tok2vec, maxout_pieces, hidden_size, nr_char
 | 
			
		||||
):
 | 
			
		||||
    output_layer = chain(
 | 
			
		||||
        list2array(),
 | 
			
		||||
        Maxout(hidden_size, nP=maxout_pieces),
 | 
			
		||||
        LayerNorm(nI=hidden_size),
 | 
			
		||||
        MultiSoftmax([256] * nr_char, nI=hidden_size),
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    model = build_masked_language_model(vocab, chain(tok2vec, output_layer))
 | 
			
		||||
    model.set_ref("tok2vec", tok2vec)
 | 
			
		||||
    model.set_ref("output_layer", output_layer)
 | 
			
		||||
    return model
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def build_masked_language_model(vocab, wrapped_model, mask_prob=0.15):
 | 
			
		||||
    """Convert a model into a BERT-style masked language model"""
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -48,7 +66,7 @@ def build_masked_language_model(vocab, wrapped_model, mask_prob=0.15):
 | 
			
		|||
    def mlm_forward(model, docs, is_train):
 | 
			
		||||
        mask, docs = _apply_mask(docs, random_words, mask_prob=mask_prob)
 | 
			
		||||
        mask = model.ops.asarray(mask).reshape((mask.shape[0], 1))
 | 
			
		||||
        output, backprop = model.get_ref("wrapped-model").begin_update(docs)
 | 
			
		||||
        output, backprop = model.layers[0](docs, is_train)
 | 
			
		||||
 | 
			
		||||
        def mlm_backward(d_output):
 | 
			
		||||
            d_output *= 1 - mask
 | 
			
		||||
| 
						 | 
				
			
			@ -56,8 +74,22 @@ def build_masked_language_model(vocab, wrapped_model, mask_prob=0.15):
 | 
			
		|||
 | 
			
		||||
        return output, mlm_backward
 | 
			
		||||
 | 
			
		||||
    mlm_model = Model("masked-language-model", mlm_forward, layers=[wrapped_model])
 | 
			
		||||
    mlm_model.set_ref("wrapped-model", wrapped_model)
 | 
			
		||||
    def mlm_initialize(model, X=None, Y=None):
 | 
			
		||||
        wrapped = model.layers[0]
 | 
			
		||||
        wrapped.initialize(X=X, Y=Y)
 | 
			
		||||
        for dim in wrapped.dim_names:
 | 
			
		||||
            if wrapped.has_dim(dim):
 | 
			
		||||
                model.set_dim(dim, wrapped.get_dim(dim))
 | 
			
		||||
 | 
			
		||||
    mlm_model = Model(
 | 
			
		||||
        "masked-language-model",
 | 
			
		||||
        mlm_forward,
 | 
			
		||||
        layers=[wrapped_model],
 | 
			
		||||
        init=mlm_initialize,
 | 
			
		||||
        refs={"wrapped": wrapped_model},
 | 
			
		||||
        dims={dim: None for dim in wrapped_model.dim_names},
 | 
			
		||||
    )
 | 
			
		||||
    mlm_model.set_ref("wrapped", wrapped_model)
 | 
			
		||||
 | 
			
		||||
    return mlm_model
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue
	
	Block a user