mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-10 19:57:17 +03:00
pretrain from config
This commit is contained in:
parent
109bbdab98
commit
ffe0451d09
144
examples/experiments/onto-joint/pretrain.cfg
Normal file
144
examples/experiments/onto-joint/pretrain.cfg
Normal file
|
@ -0,0 +1,144 @@
|
|||
# Training hyper-parameters and additional features.
|
||||
[training]
|
||||
# Whether to train on sequences with 'gold standard' sentence boundaries
|
||||
# and tokens. If you set this to true, take care to ensure your run-time
|
||||
# data is passed in sentence-by-sentence via some prior preprocessing.
|
||||
gold_preproc = false
|
||||
# Limitations on training document length or number of examples.
|
||||
max_length = 0
|
||||
limit = 0
|
||||
# Data augmentation
|
||||
orth_variant_level = 0.0
|
||||
dropout = 0.1
|
||||
# Controls early-stopping. 0 or -1 mean unlimited.
|
||||
patience = 1600
|
||||
max_epochs = 0
|
||||
max_steps = 20000
|
||||
eval_frequency = 400
|
||||
# Other settings
|
||||
seed = 0
|
||||
accumulate_gradient = 1
|
||||
use_pytorch_for_gpu_memory = false
|
||||
# Control how scores are printed and checkpoints are evaluated.
|
||||
scores = ["speed", "tags_acc", "uas", "las", "ents_f"]
|
||||
score_weights = {"las": 0.4, "ents_f": 0.4, "tags_acc": 0.2}
|
||||
# These settings are invalid for the transformer models.
|
||||
init_tok2vec = null
|
||||
vectors = null
|
||||
discard_oversize = false
|
||||
|
||||
[training.batch_size]
|
||||
@schedules = "compounding.v1"
|
||||
start = 1000
|
||||
stop = 1000
|
||||
compound = 1.001
|
||||
|
||||
[training.optimizer]
|
||||
@optimizers = "Adam.v1"
|
||||
beta1 = 0.9
|
||||
beta2 = 0.999
|
||||
L2_is_weight_decay = true
|
||||
L2 = 0.01
|
||||
grad_clip = 1.0
|
||||
use_averages = true
|
||||
eps = 1e-8
|
||||
learn_rate = 0.001
|
||||
|
||||
[pretraining]
|
||||
max_epochs = 100
|
||||
min_length = 5
|
||||
max_length = 500
|
||||
dropout = 0.2
|
||||
n_save_every = null
|
||||
batch_size = 3000
|
||||
|
||||
[pretraining.model]
|
||||
@architectures = "spacy.HashEmbedCNN.v1"
|
||||
pretrained_vectors = ${nlp:vectors}
|
||||
width = 256
|
||||
depth = 6
|
||||
window_size = 1
|
||||
embed_size = 2000
|
||||
maxout_pieces = 3
|
||||
subword_features = true
|
||||
dropout = null
|
||||
|
||||
[pretraining.optimizer]
|
||||
@optimizers = "Adam.v1"
|
||||
beta1 = 0.9
|
||||
beta2 = 0.999
|
||||
L2_is_weight_decay = true
|
||||
L2 = 0.01
|
||||
grad_clip = 1.0
|
||||
use_averages = true
|
||||
eps = 1e-8
|
||||
learn_rate = 0.001
|
||||
|
||||
[pretraining.loss_func]
|
||||
@losses = "CosineDistance.v1"
|
||||
|
||||
[nlp]
|
||||
lang = "en"
|
||||
vectors = ${training:vectors}
|
||||
|
||||
[nlp.pipeline.tok2vec]
|
||||
factory = "tok2vec"
|
||||
|
||||
[nlp.pipeline.senter]
|
||||
factory = "senter"
|
||||
|
||||
[nlp.pipeline.ner]
|
||||
factory = "ner"
|
||||
|
||||
[nlp.pipeline.tagger]
|
||||
factory = "tagger"
|
||||
|
||||
[nlp.pipeline.parser]
|
||||
factory = "parser"
|
||||
|
||||
[nlp.pipeline.senter.model]
|
||||
@architectures = "spacy.Tagger.v1"
|
||||
|
||||
[nlp.pipeline.senter.model.tok2vec]
|
||||
@architectures = "spacy.Tok2VecTensors.v1"
|
||||
width = ${nlp.pipeline.tok2vec.model:width}
|
||||
|
||||
[nlp.pipeline.tagger.model]
|
||||
@architectures = "spacy.Tagger.v1"
|
||||
|
||||
[nlp.pipeline.tagger.model.tok2vec]
|
||||
@architectures = "spacy.Tok2VecTensors.v1"
|
||||
width = ${nlp.pipeline.tok2vec.model:width}
|
||||
|
||||
[nlp.pipeline.parser.model]
|
||||
@architectures = "spacy.TransitionBasedParser.v1"
|
||||
nr_feature_tokens = 8
|
||||
hidden_width = 128
|
||||
maxout_pieces = 3
|
||||
use_upper = false
|
||||
|
||||
[nlp.pipeline.parser.model.tok2vec]
|
||||
@architectures = "spacy.Tok2VecTensors.v1"
|
||||
width = ${nlp.pipeline.tok2vec.model:width}
|
||||
|
||||
[nlp.pipeline.ner.model]
|
||||
@architectures = "spacy.TransitionBasedParser.v1"
|
||||
nr_feature_tokens = 3
|
||||
hidden_width = 128
|
||||
maxout_pieces = 3
|
||||
use_upper = false
|
||||
|
||||
[nlp.pipeline.ner.model.tok2vec]
|
||||
@architectures = "spacy.Tok2VecTensors.v1"
|
||||
width = ${nlp.pipeline.tok2vec.model:width}
|
||||
|
||||
[nlp.pipeline.tok2vec.model]
|
||||
@architectures = "spacy.HashEmbedCNN.v1"
|
||||
pretrained_vectors = ${nlp:vectors}
|
||||
width = 256
|
||||
depth = 6
|
||||
window_size = 1
|
||||
embed_size = 10000
|
||||
maxout_pieces = 3
|
||||
subword_features = true
|
||||
dropout = null
|
|
@ -3,48 +3,36 @@ import numpy
|
|||
import time
|
||||
import re
|
||||
from collections import Counter
|
||||
import plac
|
||||
from pathlib import Path
|
||||
from thinc.api import Linear, Maxout, chain, list2array, prefer_gpu
|
||||
from thinc.api import CosineDistance, L2Distance
|
||||
from thinc.api import Linear, Maxout, chain, list2array
|
||||
from wasabi import msg
|
||||
import srsly
|
||||
from thinc.api import use_pytorch_for_gpu_memory
|
||||
|
||||
from ..gold import Example
|
||||
from ..errors import Errors
|
||||
from ..ml.models.multi_task import build_masked_language_model
|
||||
from ..tokens import Doc
|
||||
from ..attrs import ID, HEAD
|
||||
from ..ml.models.tok2vec import build_Tok2Vec_model
|
||||
from .. import util
|
||||
from ..util import create_default_optimizer
|
||||
from .train import _load_pretrained_tok2vec
|
||||
from ..gold import Example
|
||||
|
||||
|
||||
def pretrain(
|
||||
@plac.annotations(
|
||||
# fmt: off
|
||||
texts_loc: ("Path to JSONL file with raw texts to learn from, with text provided as the key 'text' or tokens as the key 'tokens'", "positional", None, str),
|
||||
vectors_model: ("Name or path to spaCy model with vectors to learn from", "positional", None, str),
|
||||
output_dir: ("Directory to write models to on each epoch", "positional", None, str),
|
||||
width: ("Width of CNN layers", "option", "cw", int) = 96,
|
||||
conv_depth: ("Depth of CNN layers", "option", "cd", int) = 4,
|
||||
bilstm_depth: ("Depth of BiLSTM layers (requires PyTorch)", "option", "lstm", int) = 0,
|
||||
cnn_pieces: ("Maxout size for CNN layers. 1 for Mish", "option", "cP", int) = 3,
|
||||
sa_depth: ("Depth of self-attention layers", "option", "sa", int) = 0,
|
||||
use_chars: ("Whether to use character-based embedding", "flag", "chr", bool) = False,
|
||||
cnn_window: ("Window size for CNN layers", "option", "cW", int) = 1,
|
||||
embed_rows: ("Number of embedding rows", "option", "er", int) = 2000,
|
||||
loss_func: ("Loss function to use for the objective. Either 'L2' or 'cosine'", "option", "L", str) = "cosine",
|
||||
use_vectors: ("Whether to use the static vectors as input features", "flag", "uv") = False,
|
||||
dropout: ("Dropout rate", "option", "d", float) = 0.2,
|
||||
n_iter: ("Number of iterations to pretrain", "option", "i", int) = 1000,
|
||||
batch_size: ("Number of words per training batch", "option", "bs", int) = 3000,
|
||||
max_length: ("Max words per example. Longer examples are discarded", "option", "xw", int) = 500,
|
||||
min_length: ("Min words per example. Shorter examples are discarded", "option", "nw", int) = 5,
|
||||
seed: ("Seed for random number generators", "option", "s", int) = 0,
|
||||
n_save_every: ("Save model every X batches.", "option", "se", int) = None,
|
||||
init_tok2vec: ("Path to pretrained weights for the token-to-vector parts of the models. See 'spacy pretrain'. Experimental.", "option", "t2v", Path) = None,
|
||||
epoch_start: ("The epoch to start counting at. Only relevant when using '--init-tok2vec' and the given weight file has been renamed. Prevents unintended overwriting of existing weight files.", "option", "es", int) = None,
|
||||
texts_loc=("Path to JSONL file with raw texts to learn from, with text provided as the key 'text' or tokens as the key 'tokens'", "positional", None, str),
|
||||
vectors_model=("Name or path to spaCy model with vectors to learn from", "positional", None, str),
|
||||
config_path=("Path to config file", "positional", None, Path),
|
||||
output_dir=("Directory to write models to on each epoch", "positional", None, Path),
|
||||
use_gpu=("Use GPU", "option", "g", int),
|
||||
# fmt: on
|
||||
)
|
||||
def pretrain(
|
||||
texts_loc,
|
||||
vectors_model,
|
||||
config_path,
|
||||
output_dir,
|
||||
use_gpu=-1,
|
||||
):
|
||||
"""
|
||||
Pre-train the 'token-to-vector' (tok2vec) layer of pipeline components,
|
||||
|
@ -58,23 +46,24 @@ def pretrain(
|
|||
However, it's still quite experimental, so your mileage may vary.
|
||||
|
||||
To load the weights back in during 'spacy train', you need to ensure
|
||||
all settings are the same between pretraining and training. The API and
|
||||
errors around this need some improvement.
|
||||
all settings are the same between pretraining and training. Ideally,
|
||||
this is done by using the same config file for both commands.
|
||||
"""
|
||||
config = dict(locals())
|
||||
for key in config:
|
||||
if isinstance(config[key], Path):
|
||||
config[key] = str(config[key])
|
||||
util.fix_random_seed(seed)
|
||||
if not config_path or not config_path.exists():
|
||||
msg.fail("Config file not found", config_path, exits=1)
|
||||
|
||||
has_gpu = prefer_gpu()
|
||||
if has_gpu:
|
||||
import torch
|
||||
if use_gpu >= 0:
|
||||
msg.info("Using GPU")
|
||||
util.use_gpu(use_gpu)
|
||||
else:
|
||||
msg.info("Using CPU")
|
||||
|
||||
torch.set_default_tensor_type("torch.cuda.FloatTensor")
|
||||
msg.info("Using GPU" if has_gpu else "Not using GPU")
|
||||
msg.info(f"Loading config from: {config_path}")
|
||||
config = util.load_config(config_path, create_objects=False)
|
||||
util.fix_random_seed(config["training"]["seed"])
|
||||
if config["training"]["use_pytorch_for_gpu_memory"]:
|
||||
use_pytorch_for_gpu_memory()
|
||||
|
||||
output_dir = Path(output_dir)
|
||||
if output_dir.exists() and [p for p in output_dir.iterdir()]:
|
||||
msg.warn(
|
||||
"Output directory is not empty",
|
||||
|
@ -85,7 +74,10 @@ def pretrain(
|
|||
output_dir.mkdir()
|
||||
msg.good(f"Created output directory: {output_dir}")
|
||||
srsly.write_json(output_dir / "config.json", config)
|
||||
msg.good("Saved settings to config.json")
|
||||
msg.good("Saved config file in the output directory")
|
||||
|
||||
config = util.load_config(config_path, create_objects=True)
|
||||
pretrain_config = config["pretraining"]
|
||||
|
||||
# Load texts from file or stdin
|
||||
if texts_loc != "-": # reading from a file
|
||||
|
@ -105,49 +97,11 @@ def pretrain(
|
|||
with msg.loading(f"Loading model '{vectors_model}'..."):
|
||||
nlp = util.load_model(vectors_model)
|
||||
msg.good(f"Loaded model '{vectors_model}'")
|
||||
pretrained_vectors = None if not use_vectors else nlp.vocab.vectors
|
||||
model = create_pretraining_model(
|
||||
nlp,
|
||||
# TODO: replace with config
|
||||
build_Tok2Vec_model(
|
||||
width,
|
||||
embed_rows,
|
||||
conv_depth=conv_depth,
|
||||
pretrained_vectors=pretrained_vectors,
|
||||
bilstm_depth=bilstm_depth, # Requires PyTorch. Experimental.
|
||||
subword_features=not use_chars, # Set to False for Chinese etc
|
||||
maxout_pieces=cnn_pieces, # If set to 1, use Mish activation.
|
||||
window_size=1,
|
||||
char_embed=False,
|
||||
nM=64,
|
||||
nC=8,
|
||||
),
|
||||
)
|
||||
# Load in pretrained weights
|
||||
if init_tok2vec is not None:
|
||||
components = _load_pretrained_tok2vec(nlp, init_tok2vec)
|
||||
msg.text(f"Loaded pretrained tok2vec for: {components}")
|
||||
# Parse the epoch number from the given weight file
|
||||
model_name = re.search(r"model\d+\.bin", str(init_tok2vec))
|
||||
if model_name:
|
||||
# Default weight file name so read epoch_start from it by cutting off 'model' and '.bin'
|
||||
epoch_start = int(model_name.group(0)[5:][:-4]) + 1
|
||||
else:
|
||||
if not epoch_start:
|
||||
msg.fail(
|
||||
"You have to use the --epoch-start argument when using a renamed weight file for --init-tok2vec",
|
||||
exits=True,
|
||||
)
|
||||
elif epoch_start < 0:
|
||||
msg.fail(
|
||||
f"The argument --epoch-start has to be greater or equal to 0. {epoch_start} is invalid",
|
||||
exits=True,
|
||||
)
|
||||
else:
|
||||
# Without '--init-tok2vec' the '--epoch-start' argument is ignored
|
||||
epoch_start = 0
|
||||
tok2vec = pretrain_config["model"]
|
||||
model = create_pretraining_model(nlp, tok2vec)
|
||||
optimizer = pretrain_config["optimizer"]
|
||||
|
||||
optimizer = create_default_optimizer()
|
||||
epoch_start = 0 # TODO
|
||||
tracker = ProgressTracker(frequency=10000)
|
||||
msg.divider(f"Pre-training tok2vec layer - starting at epoch {epoch_start}")
|
||||
row_settings = {"widths": (3, 10, 10, 6, 4), "aligns": ("r", "r", "r", "r", "r")}
|
||||
|
@ -168,28 +122,25 @@ def pretrain(
|
|||
file_.write(srsly.json_dumps(log) + "\n")
|
||||
|
||||
skip_counter = 0
|
||||
for epoch in range(epoch_start, n_iter + epoch_start):
|
||||
for batch_id, batch in enumerate(
|
||||
util.minibatch_by_words(
|
||||
(Example(doc=text) for text in texts), size=batch_size
|
||||
)
|
||||
):
|
||||
loss_func = pretrain_config["loss_func"]
|
||||
for epoch in range(epoch_start, pretrain_config["max_epochs"]):
|
||||
examples = [Example(doc=text) for text in texts]
|
||||
batches = util.minibatch_by_words(examples, size=pretrain_config["batch_size"])
|
||||
for batch_id, batch in enumerate(batches):
|
||||
docs, count = make_docs(
|
||||
nlp,
|
||||
[text for (text, _) in batch],
|
||||
max_length=max_length,
|
||||
min_length=min_length,
|
||||
[ex.doc for ex in batch],
|
||||
max_length=pretrain_config["max_length"],
|
||||
min_length=pretrain_config["min_length"],
|
||||
)
|
||||
skip_counter += count
|
||||
loss = make_update(
|
||||
model, docs, optimizer, objective=loss_func, drop=dropout
|
||||
)
|
||||
loss = make_update(model, docs, optimizer, distance=loss_func)
|
||||
progress = tracker.update(epoch, loss, docs)
|
||||
if progress:
|
||||
msg.row(progress, **row_settings)
|
||||
if texts_loc == "-" and tracker.words_per_epoch[epoch] >= 10 ** 7:
|
||||
break
|
||||
if n_save_every and (batch_id % n_save_every == 0):
|
||||
if pretrain_config["n_save_every"] and (batch_id % pretrain_config["n_save_every"] == 0):
|
||||
_save_model(epoch, is_temp=True)
|
||||
_save_model(epoch)
|
||||
tracker.epoch_loss = 0.0
|
||||
|
@ -201,17 +152,17 @@ def pretrain(
|
|||
msg.good("Successfully finished pretrain")
|
||||
|
||||
|
||||
def make_update(model, docs, optimizer, drop=0.0, objective="L2"):
|
||||
def make_update(model, docs, optimizer, distance):
|
||||
"""Perform an update over a single batch of documents.
|
||||
|
||||
docs (iterable): A batch of `Doc` objects.
|
||||
drop (float): The dropout rate.
|
||||
optimizer (callable): An optimizer.
|
||||
RETURNS loss: A float for the loss.
|
||||
"""
|
||||
predictions, backprop = model.begin_update(docs, drop=drop)
|
||||
loss, gradients = get_vectors_loss(model.ops, docs, predictions, objective)
|
||||
backprop(gradients, sgd=optimizer)
|
||||
predictions, backprop = model.begin_update(docs)
|
||||
loss, gradients = get_vectors_loss(model.ops, docs, predictions, distance)
|
||||
backprop(gradients)
|
||||
model.finish_update(optimizer)
|
||||
# Don't want to return a cupy object here
|
||||
# The gradients are modified in-place by the BERT MLM,
|
||||
# so we get an accurate loss
|
||||
|
@ -243,12 +194,12 @@ def make_docs(nlp, batch, min_length, max_length):
|
|||
heads = numpy.asarray(heads, dtype="uint64")
|
||||
heads = heads.reshape((len(doc), 1))
|
||||
doc = doc.from_array([HEAD], heads)
|
||||
if len(doc) >= min_length and len(doc) < max_length:
|
||||
if min_length <= len(doc) < max_length:
|
||||
docs.append(doc)
|
||||
return docs, skip_count
|
||||
|
||||
|
||||
def get_vectors_loss(ops, docs, prediction, objective="L2"):
|
||||
def get_vectors_loss(ops, docs, prediction, distance):
|
||||
"""Compute a mean-squared error loss between the documents' vectors and
|
||||
the prediction.
|
||||
|
||||
|
@ -262,13 +213,6 @@ def get_vectors_loss(ops, docs, prediction, objective="L2"):
|
|||
# and look them up all at once. This prevents data copying.
|
||||
ids = ops.flatten([doc.to_array(ID).ravel() for doc in docs])
|
||||
target = docs[0].vocab.vectors.data[ids]
|
||||
# TODO: this code originally didn't normalize, but shouldn't normalize=True ?
|
||||
if objective == "L2":
|
||||
distance = L2Distance(normalize=False)
|
||||
elif objective == "cosine":
|
||||
distance = CosineDistance(normalize=False)
|
||||
else:
|
||||
raise ValueError(Errors.E142.format(loss_func=objective))
|
||||
d_target, loss = distance(prediction, target)
|
||||
return loss, d_target
|
||||
|
||||
|
@ -281,7 +225,7 @@ def create_pretraining_model(nlp, tok2vec):
|
|||
"""
|
||||
output_size = nlp.vocab.vectors.data.shape[1]
|
||||
output_layer = chain(
|
||||
Maxout(300, pieces=3, normalize=True, dropout=0.0), Linear(output_size)
|
||||
Maxout(nO=300, nP=3, normalize=True, dropout=0.0), Linear(output_size)
|
||||
)
|
||||
# This is annoying, but the parser etc have the flatten step after
|
||||
# the tok2vec. To load the weights in cleanly, we need to match
|
||||
|
@ -289,11 +233,12 @@ def create_pretraining_model(nlp, tok2vec):
|
|||
# "tok2vec" has to be the same set of processes as what the components do.
|
||||
tok2vec = chain(tok2vec, list2array())
|
||||
model = chain(tok2vec, output_layer)
|
||||
model = build_masked_language_model(nlp.vocab, model)
|
||||
model.set_ref("tok2vec", tok2vec)
|
||||
model.set_ref("output_layer", output_layer)
|
||||
model.initialize(X=[nlp.make_doc("Give it a doc to infer shapes")])
|
||||
return model
|
||||
mlm_model = build_masked_language_model(nlp.vocab, model)
|
||||
mlm_model.set_ref("tok2vec", tok2vec)
|
||||
mlm_model.set_ref("output_layer", output_layer)
|
||||
mlm_model.initialize(X=[nlp.make_doc("Give it a doc to infer shapes")])
|
||||
return mlm_model
|
||||
|
||||
|
||||
class ProgressTracker(object):
|
||||
|
|
|
@ -441,8 +441,6 @@ class Errors(object):
|
|||
"should be of equal length.")
|
||||
E141 = ("Entity vectors should be of length {required} instead of the "
|
||||
"provided {found}.")
|
||||
E142 = ("Unsupported loss_function '{loss_func}'. Use either 'L2' or "
|
||||
"'cosine'.")
|
||||
E143 = ("Labels for component '{name}' not initialized. Did you forget to "
|
||||
"call add_label()?")
|
||||
E144 = ("Could not find parameter `{param}` when building the entity "
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
from thinc.api import chain, Maxout, LayerNorm, Softmax, Linear, zero_init
|
||||
import numpy
|
||||
|
||||
from thinc.api import chain, Maxout, LayerNorm, Softmax, Linear, zero_init, Model
|
||||
|
||||
|
||||
def build_multi_task_model(n_tags, tok2vec=None, token_vector_width=96):
|
||||
|
@ -24,6 +26,80 @@ def build_cloze_multi_task_model(vocab, tok2vec):
|
|||
return model
|
||||
|
||||
|
||||
def build_masked_language_model(*args, **kwargs):
|
||||
# TODO cf https://github.com/explosion/spaCy/blob/2c107f02a4d60bda2440db0aad1a88cbbf4fb52d/spacy/_ml.py#L828
|
||||
raise NotImplementedError
|
||||
def build_masked_language_model(vocab, wrapped_model, mask_prob=0.15):
|
||||
"""Convert a model into a BERT-style masked language model"""
|
||||
|
||||
random_words = _RandomWords(vocab)
|
||||
|
||||
def mlm_forward(model, docs, is_train):
|
||||
mask, docs = _apply_mask(docs, random_words, mask_prob=mask_prob)
|
||||
mask = model.ops.asarray(mask).reshape((mask.shape[0], 1))
|
||||
output, backprop = model.get_ref("wrapped-model").begin_update(docs) # drop=drop
|
||||
|
||||
def mlm_backward(d_output):
|
||||
d_output *= 1 - mask
|
||||
return backprop(d_output)
|
||||
|
||||
return output, mlm_backward
|
||||
|
||||
mlm_model = Model("masked-language-model", mlm_forward, layers=[wrapped_model])
|
||||
mlm_model.set_ref("wrapped-model", wrapped_model)
|
||||
|
||||
return mlm_model
|
||||
|
||||
|
||||
class _RandomWords(object):
|
||||
def __init__(self, vocab):
|
||||
self.words = [lex.text for lex in vocab if lex.prob != 0.0]
|
||||
self.probs = [lex.prob for lex in vocab if lex.prob != 0.0]
|
||||
self.words = self.words[:10000]
|
||||
self.probs = self.probs[:10000]
|
||||
self.probs = numpy.exp(numpy.array(self.probs, dtype="f"))
|
||||
self.probs /= self.probs.sum()
|
||||
self._cache = []
|
||||
|
||||
def next(self):
|
||||
if not self._cache:
|
||||
self._cache.extend(
|
||||
numpy.random.choice(len(self.words), 10000, p=self.probs)
|
||||
)
|
||||
index = self._cache.pop()
|
||||
return self.words[index]
|
||||
|
||||
|
||||
def _apply_mask(docs, random_words, mask_prob=0.15):
|
||||
# This needs to be here to avoid circular imports
|
||||
from ...tokens import Doc
|
||||
|
||||
N = sum(len(doc) for doc in docs)
|
||||
mask = numpy.random.uniform(0.0, 1.0, (N,))
|
||||
mask = mask >= mask_prob
|
||||
i = 0
|
||||
masked_docs = []
|
||||
for doc in docs:
|
||||
words = []
|
||||
for token in doc:
|
||||
if not mask[i]:
|
||||
word = _replace_word(token.text, random_words)
|
||||
else:
|
||||
word = token.text
|
||||
words.append(word)
|
||||
i += 1
|
||||
spaces = [bool(w.whitespace_) for w in doc]
|
||||
# NB: If you change this implementation to instead modify
|
||||
# the docs in place, take care that the IDs reflect the original
|
||||
# words. Currently we use the original docs to make the vectors
|
||||
# for the target, so we don't lose the original tokens. But if
|
||||
# you modified the docs in place here, you would.
|
||||
masked_docs.append(Doc(doc.vocab, words=words, spaces=spaces))
|
||||
return mask, masked_docs
|
||||
|
||||
|
||||
def _replace_word(word, random_words, mask="[MASK]"):
|
||||
roll = numpy.random.random()
|
||||
if roll < 0.8:
|
||||
return mask
|
||||
elif roll < 0.9:
|
||||
return random_words.next()
|
||||
else:
|
||||
return word
|
Loading…
Reference in New Issue
Block a user