Mirror of https://github.com/explosion/spaCy.git

Merge pull request #5543 from svlandeg/feature/pretrain-config

pretrain from config

Commit 8411d4f4e6
Changes to an example training config:

@@ -25,6 +25,7 @@ score_weights = {"las": 0.4, "ents_f": 0.4, "tags_acc": 0.2}
 # These settings are invalid for the transformer models.
 init_tok2vec = null
 vectors = null
+discard_oversize = false
 
 [training.batch_size]
 @schedules = "compounding.v1"
@@ -32,7 +33,7 @@ start = 1000
 stop = 1000
 compound = 1.001
 
-[optimizer]
+[training.optimizer]
 @optimizers = "Adam.v1"
 beta1 = 0.9
 beta2 = 0.999
@@ -113,3 +114,4 @@ window_size = 1
 embed_size = 10000
 maxout_pieces = 3
 subword_features = true
+dropout = null
New file: examples/experiments/onto-joint/pretrain.cfg (137 lines)

# Training hyper-parameters and additional features.
[training]
# Whether to train on sequences with 'gold standard' sentence boundaries
# and tokens. If you set this to true, take care to ensure your run-time
# data is passed in sentence-by-sentence via some prior preprocessing.
gold_preproc = false
# Limitations on training document length or number of examples.
max_length = 0
limit = 0
# Data augmentation
orth_variant_level = 0.0
dropout = 0.1
# Controls early-stopping. 0 or -1 mean unlimited.
patience = 1600
max_epochs = 0
max_steps = 20000
eval_frequency = 400
# Other settings
seed = 0
accumulate_gradient = 1
use_pytorch_for_gpu_memory = false
# Control how scores are printed and checkpoints are evaluated.
scores = ["speed", "tags_acc", "uas", "las", "ents_f"]
score_weights = {"las": 0.4, "ents_f": 0.4, "tags_acc": 0.2}
# These settings are invalid for the transformer models.
init_tok2vec = null
vectors = null
discard_oversize = false

[training.batch_size]
@schedules = "compounding.v1"
start = 1000
stop = 1000
compound = 1.001

[training.optimizer]
@optimizers = "Adam.v1"
beta1 = 0.9
beta2 = 0.999
L2_is_weight_decay = true
L2 = 0.01
grad_clip = 1.0
use_averages = true
eps = 1e-8
learn_rate = 0.001

[pretraining]
max_epochs = 1000
min_length = 5
max_length = 500
dropout = 0.2
n_save_every = null
batch_size = 3000
seed = ${training:seed}
use_pytorch_for_gpu_memory = ${training:use_pytorch_for_gpu_memory}
tok2vec_model = "nlp.pipeline.tok2vec.model"

[pretraining.optimizer]
@optimizers = "Adam.v1"
beta1 = 0.9
beta2 = 0.999
L2_is_weight_decay = true
L2 = 0.01
grad_clip = 1.0
use_averages = true
eps = 1e-8
learn_rate = 0.001

[pretraining.loss_func]
@losses = "CosineDistance.v1"
normalize = true

[nlp]
lang = "en"
vectors = ${training:vectors}

[nlp.pipeline.tok2vec]
factory = "tok2vec"

[nlp.pipeline.senter]
factory = "senter"

[nlp.pipeline.ner]
factory = "ner"

[nlp.pipeline.tagger]
factory = "tagger"

[nlp.pipeline.parser]
factory = "parser"

[nlp.pipeline.senter.model]
@architectures = "spacy.Tagger.v1"

[nlp.pipeline.senter.model.tok2vec]
@architectures = "spacy.Tok2VecTensors.v1"
width = ${nlp.pipeline.tok2vec.model:width}

[nlp.pipeline.tagger.model]
@architectures = "spacy.Tagger.v1"

[nlp.pipeline.tagger.model.tok2vec]
@architectures = "spacy.Tok2VecTensors.v1"
width = ${nlp.pipeline.tok2vec.model:width}

[nlp.pipeline.parser.model]
@architectures = "spacy.TransitionBasedParser.v1"
nr_feature_tokens = 8
hidden_width = 128
maxout_pieces = 3
use_upper = false

[nlp.pipeline.parser.model.tok2vec]
@architectures = "spacy.Tok2VecTensors.v1"
width = ${nlp.pipeline.tok2vec.model:width}

[nlp.pipeline.ner.model]
@architectures = "spacy.TransitionBasedParser.v1"
nr_feature_tokens = 3
hidden_width = 128
maxout_pieces = 3
use_upper = false

[nlp.pipeline.ner.model.tok2vec]
@architectures = "spacy.Tok2VecTensors.v1"
width = ${nlp.pipeline.tok2vec.model:width}

[nlp.pipeline.tok2vec.model]
@architectures = "spacy.HashEmbedCNN.v1"
pretrained_vectors = ${nlp:vectors}
width = 256
depth = 6
window_size = 1
embed_size = 10000
maxout_pieces = 3
subword_features = true
dropout = null
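For orientation, a minimal sketch of how a config like this is consumed by the commands changed later in this diff: the file is read once without resolving registered functions, to get at raw settings such as the seed, and then a second time with create_objects=True so that blocks like [pretraining.optimizer] become real objects. The path below is the example file above; the calls mirror those visible in the pretrain/train code further down, so treat this as an assumption-laden sketch rather than a canonical recipe.

from pathlib import Path

from spacy import util

config_path = Path("examples/experiments/onto-joint/pretrain.cfg")

# First pass: raw settings only, no registered functions resolved yet.
config = util.load_config(config_path, create_objects=False)
util.fix_random_seed(config["pretraining"]["seed"])

# Second pass: resolve the @architectures, @optimizers, @schedules and @losses blocks.
config = util.load_config(config_path, create_objects=True)
pretrain_config = config["pretraining"]
optimizer = pretrain_config["optimizer"]
loss_func = pretrain_config["loss_func"]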
Changes to another example training config:

@@ -14,6 +14,7 @@ score_weights = {"las": 0.8, "tags_acc": 0.2}
 limit = 0
 seed = 0
 accumulate_gradient = 2
+discard_oversize = false
 
 [training.batch_size]
 @schedules = "compounding.v1"
@@ -21,7 +22,7 @@ start = 100
 stop = 1000
 compound = 1.001
 
-[optimizer]
+[training.optimizer]
 @optimizers = "Adam.v1"
 learn_rate = 0.001
 beta1 = 0.9
@@ -65,3 +66,4 @@ depth = 4
 embed_size = 2000
 subword_features = true
 maxout_pieces = 3
+dropout = null
The same change applied to another example training config:

@@ -14,6 +14,7 @@ score_weights = {"las": 0.8, "tags_acc": 0.2}
 limit = 0
 seed = 0
 accumulate_gradient = 2
+discard_oversize = false
 
 [training.batch_size]
 @schedules = "compounding.v1"
@@ -21,7 +22,7 @@ start = 100
 stop = 1000
 compound = 1.001
 
-[optimizer]
+[training.optimizer]
 @optimizers = "Adam.v1"
 learn_rate = 0.001
 beta1 = 0.9
@@ -66,3 +67,4 @@ window_size = 1
 embed_size = 2000
 maxout_pieces = 3
 subword_features = true
+dropout = null
Changes to another example config:

@@ -12,8 +12,9 @@ max_length = 0
 batch_size = 25
 seed = 0
 accumulate_gradient = 2
+discard_oversize = false
 
-[optimizer]
+[training.optimizer]
 @optimizers = "Adam.v1"
 learn_rate = 0.001
 beta1 = 0.9
@@ -36,6 +37,7 @@ nM = 64
 nC = 8
 rows = 2000
 columns = ["ID", "NORM", "PREFIX", "SUFFIX", "SHAPE", "ORTH"]
+dropout = null
 
 [nlp.pipeline.tok2vec.model.extract.features]
 @architectures = "spacy.Doc2Feats.v1"
Changes to another example training config:

@@ -11,6 +11,7 @@ gold_preproc = true
 max_length = 0
 seed = 0
 accumulate_gradient = 2
+discard_oversize = false
 
 [training.batch_size]
 @schedules = "compounding.v1"
@@ -19,7 +20,7 @@ stop = 3000
 compound = 1.001
 
 
-[optimizer]
+[training.optimizer]
 @optimizers = "Adam.v1"
 learn_rate = 0.001
 beta1 = 0.9
@@ -44,3 +45,4 @@ maxout_pieces = 3
 window_size = 1
 subword_features = true
 pretrained_vectors = null
+dropout = null
Deleted file (212 lines), an experimental example script:

"""This script is experimental.

Try pre-training the CNN component of the text categorizer using a cheap
language modelling-like objective. Specifically, we load pretrained vectors
(from something like word2vec, GloVe, FastText etc), and use the CNN to
predict the tokens' pretrained vectors. This isn't as easy as it sounds:
we're not merely doing compression here, because heavy dropout is applied,
including over the input words. This means the model must often (50% of the time)
use the context in order to predict the word.

To evaluate the technique, we're pre-training with the 50k texts from the IMDB
corpus, and then training with only 100 labels. Note that it's a bit dirty to
pre-train with the development data, but also not *so* terrible: we're not using
the development labels, after all --- only the unlabelled text.
"""
import plac
import tqdm
import random

import ml_datasets

import spacy
from spacy.util import minibatch
from spacy.pipeline import TextCategorizer
from spacy.ml.models.tok2vec import build_Tok2Vec_model
import numpy


def load_texts(limit=0):
    train, dev = ml_datasets.imdb()
    train_texts, train_labels = zip(*train)
    dev_texts, dev_labels = zip(*train)
    train_texts = list(train_texts)
    dev_texts = list(dev_texts)
    random.shuffle(train_texts)
    random.shuffle(dev_texts)
    if limit >= 1:
        return train_texts[:limit]
    else:
        return list(train_texts) + list(dev_texts)


def load_textcat_data(limit=0):
    """Load data from the IMDB dataset."""
    # Partition off part of the train data for evaluation
    train_data, eval_data = ml_datasets.imdb()
    random.shuffle(train_data)
    train_data = train_data[-limit:]
    texts, labels = zip(*train_data)
    eval_texts, eval_labels = zip(*eval_data)
    cats = [{"POSITIVE": bool(y), "NEGATIVE": not bool(y)} for y in labels]
    eval_cats = [{"POSITIVE": bool(y), "NEGATIVE": not bool(y)} for y in eval_labels]
    return (texts, cats), (eval_texts, eval_cats)


def prefer_gpu():
    used = spacy.util.use_gpu(0)
    if used is None:
        return False
    else:
        import cupy.random

        cupy.random.seed(0)
        return True


def build_textcat_model(tok2vec, nr_class, width):
    from thinc.api import Model, Softmax, chain, reduce_mean, list2ragged

    with Model.define_operators({">>": chain}):
        model = (
            tok2vec
            >> list2ragged()
            >> reduce_mean()
            >> Softmax(nr_class, width)
        )
    model.set_ref("tok2vec", tok2vec)
    return model


def block_gradients(model):
    from thinc.api import wrap  # TODO FIX

    def forward(X, drop=0.0):
        Y, _ = model.begin_update(X, drop=drop)
        return Y, None

    return wrap(forward, model)


def create_pipeline(width, embed_size, vectors_model):
    print("Load vectors")
    nlp = spacy.load(vectors_model)
    print("Start training")
    textcat = TextCategorizer(
        nlp.vocab,
        labels=["POSITIVE", "NEGATIVE"],
        # TODO: replace with config version
        model=build_textcat_model(
            build_Tok2Vec_model(width=width, embed_size=embed_size), 2, width
        ),
    )

    nlp.add_pipe(textcat)
    return nlp


def train_tensorizer(nlp, texts, dropout, n_iter):
    tensorizer = nlp.create_pipe("tensorizer")
    nlp.add_pipe(tensorizer)
    optimizer = nlp.begin_training()
    for i in range(n_iter):
        losses = {}
        for i, batch in enumerate(minibatch(tqdm.tqdm(texts))):
            docs = [nlp.make_doc(text) for text in batch]
            tensorizer.update((docs, None), losses=losses, sgd=optimizer, drop=dropout)
        print(losses)
    return optimizer


def train_textcat(nlp, n_texts, n_iter=10):
    textcat = nlp.get_pipe("textcat")
    tok2vec_weights = textcat.model.get_ref("tok2vec").to_bytes()
    (train_texts, train_cats), (dev_texts, dev_cats) = load_textcat_data(limit=n_texts)
    print(
        "Using {} examples ({} training, {} evaluation)".format(
            n_texts, len(train_texts), len(dev_texts)
        )
    )
    train_data = list(zip(train_texts, [{"cats": cats} for cats in train_cats]))

    with nlp.select_pipes(enable="textcat"):  # only train textcat
        optimizer = nlp.begin_training()
        textcat.model.get_ref("tok2vec").from_bytes(tok2vec_weights)
        print("Training the model...")
        print("{:^5}\t{:^5}\t{:^5}\t{:^5}".format("LOSS", "P", "R", "F"))
        for i in range(n_iter):
            losses = {"textcat": 0.0}
            # batch up the examples using spaCy's minibatch
            batches = minibatch(tqdm.tqdm(train_data), size=2)
            for batch in batches:
                nlp.update(batch, sgd=optimizer, drop=0.2, losses=losses)
            with textcat.model.use_params(optimizer.averages):
                # evaluate on the dev data split off in load_data()
                scores = evaluate_textcat(nlp.tokenizer, textcat, dev_texts, dev_cats)
            print(
                "{0:.3f}\t{1:.3f}\t{2:.3f}\t{3:.3f}".format(  # print a simple table
                    losses["textcat"],
                    scores["textcat_p"],
                    scores["textcat_r"],
                    scores["textcat_f"],
                )
            )


def evaluate_textcat(tokenizer, textcat, texts, cats):
    docs = (tokenizer(text) for text in texts)
    tp = 1e-8
    fp = 1e-8
    tn = 1e-8
    fn = 1e-8
    for i, doc in enumerate(textcat.pipe(docs)):
        gold = cats[i]
        for label, score in doc.cats.items():
            if label not in gold:
                continue
            if score >= 0.5 and gold[label] >= 0.5:
                tp += 1.0
            elif score >= 0.5 and gold[label] < 0.5:
                fp += 1.0
            elif score < 0.5 and gold[label] < 0.5:
                tn += 1
            elif score < 0.5 and gold[label] >= 0.5:
                fn += 1
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    f_score = 2 * (precision * recall) / (precision + recall)
    return {"textcat_p": precision, "textcat_r": recall, "textcat_f": f_score}


@plac.annotations(
    width=("Width of CNN layers", "positional", None, int),
    embed_size=("Embedding rows", "positional", None, int),
    pretrain_iters=("Number of iterations to pretrain", "option", "pn", int),
    train_iters=("Number of iterations to pretrain", "option", "tn", int),
    train_examples=("Number of labelled examples", "option", "eg", int),
    vectors_model=("Name or path to vectors model to learn from"),
)
def main(
    width,
    embed_size,
    vectors_model,
    pretrain_iters=30,
    train_iters=30,
    train_examples=1000,
):
    random.seed(0)
    numpy.random.seed(0)
    use_gpu = prefer_gpu()
    print("Using GPU?", use_gpu)

    nlp = create_pipeline(width, embed_size, vectors_model)
    print("Load data")
    texts = load_texts(limit=0)
    print("Train tensorizer")
    optimizer = train_tensorizer(nlp, texts, dropout=0.2, n_iter=pretrain_iters)
    print("Train textcat")
    train_textcat(nlp, train_examples, n_iter=train_iters)


if __name__ == "__main__":
    plac.call(main)
Changes to the command registration in the package's __main__ block:

@@ -2,16 +2,15 @@ if __name__ == "__main__":
     import plac
     import sys
     from wasabi import msg
-    from spacy.cli import download, link, info, package, train, pretrain, convert
+    from spacy.cli import download, link, info, package, pretrain, convert
     from spacy.cli import init_model, profile, evaluate, validate, debug_data
-    from spacy.cli import train_from_config_cli
+    from spacy.cli import train_cli
 
     commands = {
         "download": download,
         "link": link,
         "info": info,
-        "train": train,
-        "train-from-config": train_from_config_cli,
+        "train": train_cli,
         "pretrain": pretrain,
         "debug-data": debug_data,
         "evaluate": evaluate,
Changes to the CLI package imports:

@@ -4,8 +4,7 @@ from .download import download  # noqa: F401
 from .info import info  # noqa: F401
 from .package import package  # noqa: F401
 from .profile import profile  # noqa: F401
-from .train import train  # noqa: F401
-from .train_from_config import train_from_config_cli  # noqa: F401
+from .train_from_config import train_cli  # noqa: F401
 from .pretrain import pretrain  # noqa: F401
 from .debug_data import debug_data  # noqa: F401
 from .evaluate import evaluate  # noqa: F401
Rewrite of the pretrain CLI so that its settings come from the config file:

@@ -3,48 +3,39 @@ import numpy
 import time
 import re
 from collections import Counter
+import plac
 from pathlib import Path
-from thinc.api import Linear, Maxout, chain, list2array, prefer_gpu
-from thinc.api import CosineDistance, L2Distance
+from thinc.api import Linear, Maxout, chain, list2array, use_pytorch_for_gpu_memory
 from wasabi import msg
 import srsly
 
-from ..gold import Example
 from ..errors import Errors
 from ..ml.models.multi_task import build_masked_language_model
 from ..tokens import Doc
 from ..attrs import ID, HEAD
-from ..ml.models.tok2vec import build_Tok2Vec_model
 from .. import util
-from ..util import create_default_optimizer
-from .train import _load_pretrained_tok2vec
+from ..gold import Example
 
 
-def pretrain(
+@plac.annotations(
     # fmt: off
-    texts_loc: ("Path to JSONL file with raw texts to learn from, with text provided as the key 'text' or tokens as the key 'tokens'", "positional", None, str),
-    vectors_model: ("Name or path to spaCy model with vectors to learn from", "positional", None, str),
-    output_dir: ("Directory to write models to on each epoch", "positional", None, str),
-    width: ("Width of CNN layers", "option", "cw", int) = 96,
-    conv_depth: ("Depth of CNN layers", "option", "cd", int) = 4,
-    bilstm_depth: ("Depth of BiLSTM layers (requires PyTorch)", "option", "lstm", int) = 0,
-    cnn_pieces: ("Maxout size for CNN layers. 1 for Mish", "option", "cP", int) = 3,
-    sa_depth: ("Depth of self-attention layers", "option", "sa", int) = 0,
-    use_chars: ("Whether to use character-based embedding", "flag", "chr", bool) = False,
-    cnn_window: ("Window size for CNN layers", "option", "cW", int) = 1,
-    embed_rows: ("Number of embedding rows", "option", "er", int) = 2000,
-    loss_func: ("Loss function to use for the objective. Either 'L2' or 'cosine'", "option", "L", str) = "cosine",
-    use_vectors: ("Whether to use the static vectors as input features", "flag", "uv") = False,
-    dropout: ("Dropout rate", "option", "d", float) = 0.2,
-    n_iter: ("Number of iterations to pretrain", "option", "i", int) = 1000,
-    batch_size: ("Number of words per training batch", "option", "bs", int) = 3000,
-    max_length: ("Max words per example. Longer examples are discarded", "option", "xw", int) = 500,
-    min_length: ("Min words per example. Shorter examples are discarded", "option", "nw", int) = 5,
-    seed: ("Seed for random number generators", "option", "s", int) = 0,
-    n_save_every: ("Save model every X batches.", "option", "se", int) = None,
-    init_tok2vec: ("Path to pretrained weights for the token-to-vector parts of the models. See 'spacy pretrain'. Experimental.", "option", "t2v", Path) = None,
-    epoch_start: ("The epoch to start counting at. Only relevant when using '--init-tok2vec' and the given weight file has been renamed. Prevents unintended overwriting of existing weight files.", "option", "es", int) = None,
+    texts_loc=("Path to JSONL file with raw texts to learn from, with text provided as the key 'text' or tokens as the key 'tokens'", "positional", None, str),
+    vectors_model=("Name or path to spaCy model with vectors to learn from", "positional", None, str),
+    output_dir=("Directory to write models to on each epoch", "positional", None, Path),
+    config_path=("Path to config file", "positional", None, Path),
+    use_gpu=("Use GPU", "option", "g", int),
+    resume_path=("Path to pretrained weights from which to resume pretraining", "option","r", Path),
+    epoch_resume=("The epoch to resume counting from when using '--resume_path'. Prevents unintended overwriting of existing weight files.","option", "er", int),
     # fmt: on
+)
+def pretrain(
+    texts_loc,
+    vectors_model,
+    config_path,
+    output_dir,
+    use_gpu=-1,
+    resume_path=None,
+    epoch_resume=None,
 ):
     """
     Pre-train the 'token-to-vector' (tok2vec) layer of pipeline components,
@@ -58,34 +49,46 @@ def pretrain(
     However, it's still quite experimental, so your mileage may vary.
 
     To load the weights back in during 'spacy train', you need to ensure
-    all settings are the same between pretraining and training. The API and
-    errors around this need some improvement.
+    all settings are the same between pretraining and training. Ideally,
+    this is done by using the same config file for both commands.
     """
-    config = dict(locals())
-    for key in config:
-        if isinstance(config[key], Path):
-            config[key] = str(config[key])
-    util.fix_random_seed(seed)
+    if not config_path or not config_path.exists():
+        msg.fail("Config file not found", config_path, exits=1)
 
-    has_gpu = prefer_gpu()
-    if has_gpu:
-        import torch
+    if use_gpu >= 0:
+        msg.info("Using GPU")
+        util.use_gpu(use_gpu)
+    else:
+        msg.info("Using CPU")
 
-        torch.set_default_tensor_type("torch.cuda.FloatTensor")
-    msg.info("Using GPU" if has_gpu else "Not using GPU")
+    msg.info(f"Loading config from: {config_path}")
+    config = util.load_config(config_path, create_objects=False)
+    util.fix_random_seed(config["pretraining"]["seed"])
+    if config["pretraining"]["use_pytorch_for_gpu_memory"]:
+        use_pytorch_for_gpu_memory()
 
-    output_dir = Path(output_dir)
     if output_dir.exists() and [p for p in output_dir.iterdir()]:
-        msg.warn(
-            "Output directory is not empty",
-            "It is better to use an empty directory or refer to a new output path, "
-            "then the new directory will be created for you.",
-        )
+        if resume_path:
+            msg.warn(
+                "Output directory is not empty. ",
+                "If you're resuming a run from a previous model in this directory, "
+                "the old models for the consecutive epochs will be overwritten "
+                "with the new ones.",
+            )
+        else:
+            msg.warn(
+                "Output directory is not empty. ",
+                "It is better to use an empty directory or refer to a new output path, "
+                "then the new directory will be created for you.",
+            )
     if not output_dir.exists():
        output_dir.mkdir()
        msg.good(f"Created output directory: {output_dir}")
     srsly.write_json(output_dir / "config.json", config)
-    msg.good("Saved settings to config.json")
+    msg.good("Saved config file in the output directory")
+
+    config = util.load_config(config_path, create_objects=True)
+    pretrain_config = config["pretraining"]
 
     # Load texts from file or stdin
     if texts_loc != "-":  # reading from a file
@@ -99,57 +102,50 @@ def pretrain(
         msg.good("Loaded input texts")
         random.shuffle(texts)
     else:  # reading from stdin
-        msg.text("Reading input text from stdin...")
+        msg.info("Reading input text from stdin...")
         texts = srsly.read_jsonl("-")
 
     with msg.loading(f"Loading model '{vectors_model}'..."):
         nlp = util.load_model(vectors_model)
     msg.good(f"Loaded model '{vectors_model}'")
-    pretrained_vectors = None if not use_vectors else nlp.vocab.vectors
-    model = create_pretraining_model(
-        nlp,
-        # TODO: replace with config
-        build_Tok2Vec_model(
-            width,
-            embed_rows,
-            conv_depth=conv_depth,
-            pretrained_vectors=pretrained_vectors,
-            bilstm_depth=bilstm_depth,  # Requires PyTorch. Experimental.
-            subword_features=not use_chars,  # Set to False for Chinese etc
-            maxout_pieces=cnn_pieces,  # If set to 1, use Mish activation.
-            window_size=1,
-            char_embed=False,
-            nM=64,
-            nC=8,
-        ),
-    )
-    # Load in pretrained weights
-    if init_tok2vec is not None:
-        components = _load_pretrained_tok2vec(nlp, init_tok2vec)
-        msg.text(f"Loaded pretrained tok2vec for: {components}")
+    tok2vec_path = pretrain_config["tok2vec_model"]
+    tok2vec = config
+    for subpath in tok2vec_path.split("."):
+        tok2vec = tok2vec.get(subpath)
+    model = create_pretraining_model(nlp, tok2vec)
+    optimizer = pretrain_config["optimizer"]
+
+    # Load in pretrained weights to resume from
+    if resume_path is not None:
+        msg.info(f"Resume training tok2vec from: {resume_path}")
+        with resume_path.open("rb") as file_:
+            weights_data = file_.read()
+        model.get_ref("tok2vec").from_bytes(weights_data)
         # Parse the epoch number from the given weight file
-        model_name = re.search(r"model\d+\.bin", str(init_tok2vec))
+        model_name = re.search(r"model\d+\.bin", str(resume_path))
         if model_name:
             # Default weight file name so read epoch_start from it by cutting off 'model' and '.bin'
-            epoch_start = int(model_name.group(0)[5:][:-4]) + 1
+            epoch_resume = int(model_name.group(0)[5:][:-4]) + 1
+            msg.info(f"Resuming from epoch: {epoch_resume}")
         else:
-            if not epoch_start:
+            if not epoch_resume:
                 msg.fail(
-                    "You have to use the --epoch-start argument when using a renamed weight file for --init-tok2vec",
+                    "You have to use the --epoch-resume setting when using a renamed weight file for --resume-path",
                     exits=True,
                 )
-            elif epoch_start < 0:
+            elif epoch_resume < 0:
                 msg.fail(
-                    f"The argument --epoch-start has to be greater or equal to 0. {epoch_start} is invalid",
+                    f"The argument --epoch-resume has to be greater or equal to 0. {epoch_resume} is invalid",
                     exits=True,
                 )
+            else:
+                msg.info(f"Resuming from epoch: {epoch_resume}")
     else:
-        # Without '--init-tok2vec' the '--epoch-start' argument is ignored
-        epoch_start = 0
+        # Without '--resume-path' the '--epoch-resume' argument is ignored
+        epoch_resume = 0
 
-    optimizer = create_default_optimizer()
     tracker = ProgressTracker(frequency=10000)
-    msg.divider(f"Pre-training tok2vec layer - starting at epoch {epoch_start}")
+    msg.divider(f"Pre-training tok2vec layer - starting at epoch {epoch_resume}")
     row_settings = {"widths": (3, 10, 10, 6, 4), "aligns": ("r", "r", "r", "r", "r")}
     msg.row(("#", "# Words", "Total Loss", "Loss", "w/s"), **row_settings)
@@ -168,28 +164,27 @@ def pretrain(
             file_.write(srsly.json_dumps(log) + "\n")
 
     skip_counter = 0
-    for epoch in range(epoch_start, n_iter + epoch_start):
-        for batch_id, batch in enumerate(
-            util.minibatch_by_words(
-                (Example(doc=text) for text in texts), size=batch_size
-            )
-        ):
+    loss_func = pretrain_config["loss_func"]
+    for epoch in range(epoch_resume, pretrain_config["max_epochs"]):
+        examples = [Example(doc=text) for text in texts]
+        batches = util.minibatch_by_words(examples, size=pretrain_config["batch_size"])
+        for batch_id, batch in enumerate(batches):
             docs, count = make_docs(
                 nlp,
-                [text for (text, _) in batch],
-                max_length=max_length,
-                min_length=min_length,
+                [ex.doc for ex in batch],
+                max_length=pretrain_config["max_length"],
+                min_length=pretrain_config["min_length"],
             )
             skip_counter += count
-            loss = make_update(
-                model, docs, optimizer, objective=loss_func, drop=dropout
-            )
+            loss = make_update(model, docs, optimizer, distance=loss_func)
             progress = tracker.update(epoch, loss, docs)
             if progress:
                 msg.row(progress, **row_settings)
                 if texts_loc == "-" and tracker.words_per_epoch[epoch] >= 10 ** 7:
                     break
-            if n_save_every and (batch_id % n_save_every == 0):
+            if pretrain_config["n_save_every"] and (
+                batch_id % pretrain_config["n_save_every"] == 0
+            ):
                 _save_model(epoch, is_temp=True)
         _save_model(epoch)
         tracker.epoch_loss = 0.0
@@ -201,17 +196,17 @@ def pretrain(
     msg.good("Successfully finished pretrain")


-def make_update(model, docs, optimizer, drop=0.0, objective="L2"):
+def make_update(model, docs, optimizer, distance):
     """Perform an update over a single batch of documents.
 
     docs (iterable): A batch of `Doc` objects.
-    drop (float): The dropout rate.
     optimizer (callable): An optimizer.
     RETURNS loss: A float for the loss.
     """
-    predictions, backprop = model.begin_update(docs, drop=drop)
-    loss, gradients = get_vectors_loss(model.ops, docs, predictions, objective)
-    backprop(gradients, sgd=optimizer)
+    predictions, backprop = model.begin_update(docs)
+    loss, gradients = get_vectors_loss(model.ops, docs, predictions, distance)
+    backprop(gradients)
+    model.finish_update(optimizer)
     # Don't want to return a cupy object here
     # The gradients are modified in-place by the BERT MLM,
     # so we get an accurate loss
@@ -243,12 +238,12 @@ def make_docs(nlp, batch, min_length, max_length):
             heads = numpy.asarray(heads, dtype="uint64")
             heads = heads.reshape((len(doc), 1))
             doc = doc.from_array([HEAD], heads)
-        if len(doc) >= min_length and len(doc) < max_length:
+        if min_length <= len(doc) < max_length:
             docs.append(doc)
     return docs, skip_count


-def get_vectors_loss(ops, docs, prediction, objective="L2"):
+def get_vectors_loss(ops, docs, prediction, distance):
     """Compute a mean-squared error loss between the documents' vectors and
     the prediction.
 
@@ -262,13 +257,6 @@ def get_vectors_loss(ops, docs, prediction, objective="L2"):
     # and look them up all at once. This prevents data copying.
     ids = ops.flatten([doc.to_array(ID).ravel() for doc in docs])
     target = docs[0].vocab.vectors.data[ids]
-    # TODO: this code originally didn't normalize, but shouldn't normalize=True ?
-    if objective == "L2":
-        distance = L2Distance(normalize=False)
-    elif objective == "cosine":
-        distance = CosineDistance(normalize=False)
-    else:
-        raise ValueError(Errors.E142.format(loss_func=objective))
     d_target, loss = distance(prediction, target)
     return loss, d_target
 
@@ -281,7 +269,7 @@ def create_pretraining_model(nlp, tok2vec):
     """
     output_size = nlp.vocab.vectors.data.shape[1]
     output_layer = chain(
-        Maxout(300, pieces=3, normalize=True, dropout=0.0), Linear(output_size)
+        Maxout(nO=300, nP=3, normalize=True, dropout=0.0), Linear(output_size)
     )
     # This is annoying, but the parser etc have the flatten step after
     # the tok2vec. To load the weights in cleanly, we need to match
@@ -289,11 +277,12 @@ def create_pretraining_model(nlp, tok2vec):
     # "tok2vec" has to be the same set of processes as what the components do.
     tok2vec = chain(tok2vec, list2array())
     model = chain(tok2vec, output_layer)
-    model = build_masked_language_model(nlp.vocab, model)
-    model.set_ref("tok2vec", tok2vec)
-    model.set_ref("output_layer", output_layer)
     model.initialize(X=[nlp.make_doc("Give it a doc to infer shapes")])
-    return model
+    mlm_model = build_masked_language_model(nlp.vocab, model)
+    mlm_model.set_ref("tok2vec", tok2vec)
+    mlm_model.set_ref("output_layer", output_layer)
+    mlm_model.initialize(X=[nlp.make_doc("Give it a doc to infer shapes")])
+    return mlm_model
 
 
 class ProgressTracker(object):
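For reference, a hedged illustration of how the loss object supplied via [pretraining.loss_func] is used inside get_vectors_loss(): a thinc loss is called on the prediction and the target vectors and returns the gradient plus a scalar loss. The shapes below are made up, and it is an assumption here that the registered "CosineDistance.v1" loss corresponds to thinc's CosineDistance class.

import numpy
from thinc.api import CosineDistance

distance = CosineDistance(normalize=True)  # assumed stand-in for "@losses = CosineDistance.v1"
prediction = numpy.random.uniform(-1, 1, (4, 300)).astype("f")  # model output for 4 tokens
target = numpy.random.uniform(-1, 1, (4, 300)).astype("f")      # the tokens' pretrained vectors
d_target, loss = distance(prediction, target)  # gradient w.r.t. the prediction, plus a scalar loss
print(float(loss), d_target.shape)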
Changes to the config-based training code (train-from-config becomes the regular train command):

@@ -13,6 +13,7 @@ import random
 from ..gold import GoldCorpus
 from .. import util
 from ..errors import Errors
+from ..ml import models  # don't remove - required to load the built-in architectures
 
 registry = util.registry
 
@@ -123,7 +124,7 @@ class ConfigSchema(BaseModel):
     use_gpu=("Use GPU", "option", "g", int),
     # fmt: on
 )
-def train_from_config_cli(
+def train_cli(
     train_path,
     dev_path,
     config_path,
@@ -132,7 +133,7 @@ def train_from_config_cli(
     raw_text=None,
     debug=False,
     verbose=False,
-    use_gpu=-1
+    use_gpu=-1,
 ):
     """
     Train or update a spaCy model. Requires data to be formatted in spaCy's
@@ -156,7 +157,7 @@ def train_from_config_cli(
     else:
         msg.info("Using CPU")
 
-    train_from_config(
+    train(
         config_path,
         {"train": train_path, "dev": dev_path},
         output_path=output_path,
@@ -165,10 +166,11 @@ def train_from_config_cli(
     )
 
 
-def train_from_config(
+def train(
     config_path, data_paths, raw_text=None, meta_path=None, output_path=None,
 ):
     msg.info(f"Loading config from: {config_path}")
+    # Read the config first without creating objects, to get to the original nlp_config
     config = util.load_config(config_path, create_objects=False)
     util.fix_random_seed(config["training"]["seed"])
     if config["training"]["use_pytorch_for_gpu_memory"]:
@@ -177,8 +179,8 @@ def train_from_config(
     config = util.load_config(config_path, create_objects=True)
     msg.info("Creating nlp from config")
     nlp = util.load_model_from_config(nlp_config)
-    optimizer = config["optimizer"]
     training = config["training"]
+    optimizer = training["optimizer"]
     limit = training["limit"]
     msg.info("Loading training corpus")
     corpus = GoldCorpus(data_paths["train"], data_paths["dev"], limit=limit)
@@ -246,13 +248,19 @@ def create_train_batches(nlp, corpus, cfg):
         if len(train_examples) == 0:
             raise ValueError(Errors.E988)
         random.shuffle(train_examples)
-        batches = util.minibatch_by_words(train_examples, size=cfg["batch_size"])
+        batches = util.minibatch_by_words(train_examples, size=cfg["batch_size"], discard_oversize=cfg["discard_oversize"])
+        # make sure the minibatch_by_words result is not empty, or we'll have an infinite training loop
+        try:
+            first = next(batches)
+            yield first
+        except StopIteration:
+            raise ValueError(Errors.E986)
         for batch in batches:
             yield batch
         epochs_todo -= 1
         # We intentionally compare exactly to 0 here, so that max_epochs < 1
         # will not break.
         if epochs_todo == 0:
             break
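As a side note, a small self-contained sketch (not spaCy code) of the guard added to create_train_batches() above: pulling the first batch eagerly turns an empty batch generator into an error as soon as iteration starts, instead of letting the surrounding while-True loop spin forever. The function name below is made up for illustration.

def batches_or_error(batches):
    # Peek at the first batch so an empty generator fails loudly.
    try:
        first = next(batches)
    except StopIteration:
        raise ValueError("no training batches could be created")
    yield first
    yield from batches

print(list(batches_or_error(iter([[1, 2], [3]]))))  # [[1, 2], [3]]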
Changes to the error messages (E142 removed, E986 added):

@@ -453,8 +453,6 @@ class Errors(object):
             "should be of equal length.")
     E141 = ("Entity vectors should be of length {required} instead of the "
             "provided {found}.")
-    E142 = ("Unsupported loss_function '{loss_func}'. Use either 'L2' or "
-            "'cosine'.")
     E143 = ("Labels for component '{name}' not initialized. Did you forget to "
             "call add_label()?")
     E144 = ("Could not find parameter `{param}` when building the entity "
@@ -577,6 +575,8 @@ class Errors(object):
 
     # TODO: fix numbering after merging develop into master
 
+    E986 = ("Could not create any training batches: check your input. "
+            "Perhaps discard_oversize should be set to False ?")
     E987 = ("The text of an example training instance is either a Doc or "
             "a string, but found {type} instead.")
     E988 = ("Could not parse any training examples. Ensure the data is "
Removal of the Language.tensorizer convenience property:

@@ -231,10 +231,6 @@ class Language(object):
 
     # Conveniences to access pipeline components
     # Shouldn't be used anymore!
-    @property
-    def tensorizer(self):
-        return self.get_pipe("tensorizer")
-
     @property
     def tagger(self):
         return self.get_pipe("tagger")
New file (1 line):

+from .models import *
Changes to the model package imports (tensorizer removed):

@@ -2,6 +2,5 @@ from .entity_linker import *  # noqa
 from .parser import *  # noqa
 from .simple_ner import *
 from .tagger import *  # noqa
-from .tensorizer import *  # noqa
 from .textcat import *  # noqa
 from .tok2vec import *  # noqa
Implementation of the BERT-style masked-language-model wrapper used by pretraining:

@@ -1,4 +1,6 @@
-from thinc.api import chain, Maxout, LayerNorm, Softmax, Linear, zero_init
+import numpy
+
+from thinc.api import chain, Maxout, LayerNorm, Softmax, Linear, zero_init, Model
 
 
 def build_multi_task_model(n_tags, tok2vec=None, token_vector_width=96):
@@ -24,6 +26,80 @@ def build_cloze_multi_task_model(vocab, tok2vec):
     return model
 
 
-def build_masked_language_model(*args, **kwargs):
-    # TODO cf https://github.com/explosion/spaCy/blob/2c107f02a4d60bda2440db0aad1a88cbbf4fb52d/spacy/_ml.py#L828
-    raise NotImplementedError
+def build_masked_language_model(vocab, wrapped_model, mask_prob=0.15):
+    """Convert a model into a BERT-style masked language model"""
+
+    random_words = _RandomWords(vocab)
+
+    def mlm_forward(model, docs, is_train):
+        mask, docs = _apply_mask(docs, random_words, mask_prob=mask_prob)
+        mask = model.ops.asarray(mask).reshape((mask.shape[0], 1))
+        output, backprop = model.get_ref("wrapped-model").begin_update(docs)  # drop=drop
+
+        def mlm_backward(d_output):
+            d_output *= 1 - mask
+            return backprop(d_output)
+
+        return output, mlm_backward
+
+    mlm_model = Model("masked-language-model", mlm_forward, layers=[wrapped_model])
+    mlm_model.set_ref("wrapped-model", wrapped_model)
+
+    return mlm_model
+
+
+class _RandomWords(object):
+    def __init__(self, vocab):
+        self.words = [lex.text for lex in vocab if lex.prob != 0.0]
+        self.probs = [lex.prob for lex in vocab if lex.prob != 0.0]
+        self.words = self.words[:10000]
+        self.probs = self.probs[:10000]
+        self.probs = numpy.exp(numpy.array(self.probs, dtype="f"))
+        self.probs /= self.probs.sum()
+        self._cache = []
+
+    def next(self):
+        if not self._cache:
+            self._cache.extend(
+                numpy.random.choice(len(self.words), 10000, p=self.probs)
+            )
+        index = self._cache.pop()
+        return self.words[index]
+
+
+def _apply_mask(docs, random_words, mask_prob=0.15):
+    # This needs to be here to avoid circular imports
+    from ...tokens import Doc
+
+    N = sum(len(doc) for doc in docs)
+    mask = numpy.random.uniform(0.0, 1.0, (N,))
+    mask = mask >= mask_prob
+    i = 0
+    masked_docs = []
+    for doc in docs:
+        words = []
+        for token in doc:
+            if not mask[i]:
+                word = _replace_word(token.text, random_words)
+            else:
+                word = token.text
+            words.append(word)
+            i += 1
+        spaces = [bool(w.whitespace_) for w in doc]
+        # NB: If you change this implementation to instead modify
+        # the docs in place, take care that the IDs reflect the original
+        # words. Currently we use the original docs to make the vectors
+        # for the target, so we don't lose the original tokens. But if
+        # you modified the docs in place here, you would.
+        masked_docs.append(Doc(doc.vocab, words=words, spaces=spaces))
+    return mask, masked_docs
+
+
+def _replace_word(word, random_words, mask="[MASK]"):
+    roll = numpy.random.random()
+    if roll < 0.8:
+        return mask
+    elif roll < 0.9:
+        return random_words.next()
+    else:
+        return word
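A tiny illustration (not part of the diff) of the gradient masking in mlm_backward() above: mask is True for tokens that were left intact, so multiplying the incoming gradient by (1 - mask) keeps a learning signal only for the corrupted positions.

import numpy

mask = numpy.array([[True], [False], [True], [False]])  # 4 tokens, 2 were corrupted
d_output = numpy.ones((4, 3), dtype="f")                # fake incoming gradient, width 3
d_output *= 1 - mask
print(d_output)  # rows for intact tokens are zeroed; corrupted rows keep their gradient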
Deleted file (10 lines), the Tensorizer model architecture:

from thinc.api import Linear, zero_init

from ... import util
from ...util import registry


@registry.architectures.register("spacy.Tensorizer.v1")
def build_tensorizer(input_size, output_size):
    input_size = util.env_opt("token_vector_width", input_size)
    return Linear(output_size, input_size, init_W=zero_init)
Changes to the text classification architectures (dropout is now configurable):

@@ -49,13 +49,13 @@ def build_bow_text_classifier(exclusive_classes, ngram_size, no_output_layer, nO
 
 @registry.architectures.register("spacy.TextCat.v1")
 def build_text_classifier(width, embed_size, pretrained_vectors, exclusive_classes, ngram_size,
-                          window_size, conv_depth, nO=None):
+                          window_size, conv_depth, dropout, nO=None):
     cols = [ORTH, LOWER, PREFIX, SUFFIX, SHAPE, ID]
     with Model.define_operators({">>": chain, "|": concatenate, "**": clone}):
-        lower = HashEmbed(nO=width, nV=embed_size, column=cols.index(LOWER))
-        prefix = HashEmbed(nO=width // 2, nV=embed_size, column=cols.index(PREFIX))
-        suffix = HashEmbed(nO=width // 2, nV=embed_size, column=cols.index(SUFFIX))
-        shape = HashEmbed(nO=width // 2, nV=embed_size, column=cols.index(SHAPE))
+        lower = HashEmbed(nO=width, nV=embed_size, column=cols.index(LOWER), dropout=dropout)
+        prefix = HashEmbed(nO=width // 2, nV=embed_size, column=cols.index(PREFIX), dropout=dropout)
+        suffix = HashEmbed(nO=width // 2, nV=embed_size, column=cols.index(SUFFIX), dropout=dropout)
+        shape = HashEmbed(nO=width // 2, nV=embed_size, column=cols.index(SHAPE), dropout=dropout)
 
         width_nI = sum(layer.get_dim("nO") for layer in [lower, prefix, suffix, shape])
         trained_vectors = FeatureExtractor(cols) >> with_array(
@@ -114,7 +114,7 @@ def build_text_classifier(width, embed_size, pretrained_vectors, exclusive_class
 
 
 @registry.architectures.register("spacy.TextCatLowData.v1")
-def build_text_classifier_lowdata(width, pretrained_vectors, nO=None):
+def build_text_classifier_lowdata(width, pretrained_vectors, dropout, nO=None):
     nlp = util.load_model(pretrained_vectors)
     vectors = nlp.vocab.vectors
     vector_dim = vectors.data.shape[1]
@@ -129,7 +129,8 @@ def build_text_classifier_lowdata(width, pretrained_vectors, nO=None):
             >> reduce_sum()
             >> residual(Relu(width, width)) ** 2
             >> Linear(nO, width)
-            >> Dropout(0.0)
-            >> Logistic()
         )
+        if dropout:
+            model = model >> Dropout(dropout)
+        model = model >> Logistic()
     return model
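A hedged, standalone thinc sketch of the conditional composition pattern introduced in build_text_classifier_lowdata() above: the Dropout layer is only chained in when a rate is actually configured. The helper name and sizes below are made up for illustration.

from thinc.api import Dropout, Linear, Logistic, chain

def build_head(width, n_out, dropout=None):
    # Base projection; only add Dropout when a rate is given, then the output activation.
    model = Linear(n_out, width)
    if dropout:
        model = chain(model, Dropout(dropout))
    return chain(model, Logistic())

print(build_head(64, 2, dropout=0.2))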
@@ -49,6 +49,7 @@ def hash_embed_cnn(
     maxout_pieces,
     window_size,
     subword_features,
+    dropout,
 ):
     # Does not use character embeddings: set to False by default
     return build_Tok2Vec_model(
@@ -63,6 +64,7 @@ def hash_embed_cnn(
         char_embed=False,
         nM=0,
         nC=0,
+        dropout=dropout,
     )


@@ -76,6 +78,7 @@ def hash_charembed_cnn(
     window_size,
     nM,
     nC,
+    dropout,
 ):
     # Allows using character embeddings by setting nC, nM and char_embed=True
     return build_Tok2Vec_model(
@@ -90,12 +93,13 @@ def hash_charembed_cnn(
         char_embed=True,
         nM=nM,
         nC=nC,
+        dropout=dropout,
     )


 @registry.architectures.register("spacy.HashEmbedBiLSTM.v1")
 def hash_embed_bilstm_v1(
-    pretrained_vectors, width, depth, embed_size, subword_features, maxout_pieces
+    pretrained_vectors, width, depth, embed_size, subword_features, maxout_pieces, dropout
 ):
     # Does not use character embeddings: set to False by default
     return build_Tok2Vec_model(
@@ -110,12 +114,13 @@ def hash_embed_bilstm_v1(
         char_embed=False,
         nM=0,
         nC=0,
+        dropout=dropout,
     )


 @registry.architectures.register("spacy.HashCharEmbedBiLSTM.v1")
 def hash_char_embed_bilstm_v1(
-    pretrained_vectors, width, depth, embed_size, maxout_pieces, nM, nC
+    pretrained_vectors, width, depth, embed_size, maxout_pieces, nM, nC, dropout
 ):
     # Allows using character embeddings by setting nC, nM and char_embed=True
     return build_Tok2Vec_model(
@@ -130,6 +135,7 @@ def hash_char_embed_bilstm_v1(
         char_embed=True,
         nM=nM,
         nC=nC,
+        dropout=dropout,
     )


@@ -144,19 +150,19 @@ def LayerNormalizedMaxout(width, maxout_pieces):


 @registry.architectures.register("spacy.MultiHashEmbed.v1")
-def MultiHashEmbed(columns, width, rows, use_subwords, pretrained_vectors, mix):
-    norm = HashEmbed(nO=width, nV=rows, column=columns.index("NORM"))
+def MultiHashEmbed(columns, width, rows, use_subwords, pretrained_vectors, mix, dropout):
+    norm = HashEmbed(nO=width, nV=rows, column=columns.index("NORM"), dropout=dropout)
     if use_subwords:
-        prefix = HashEmbed(nO=width, nV=rows // 2, column=columns.index("PREFIX"))
-        suffix = HashEmbed(nO=width, nV=rows // 2, column=columns.index("SUFFIX"))
-        shape = HashEmbed(nO=width, nV=rows // 2, column=columns.index("SHAPE"))
+        prefix = HashEmbed(nO=width, nV=rows // 2, column=columns.index("PREFIX"), dropout=dropout)
+        suffix = HashEmbed(nO=width, nV=rows // 2, column=columns.index("SUFFIX"), dropout=dropout)
+        shape = HashEmbed(nO=width, nV=rows // 2, column=columns.index("SHAPE"), dropout=dropout)

     if pretrained_vectors:
         glove = StaticVectors(
             vectors=pretrained_vectors.data,
             nO=width,
             column=columns.index(ID),
-            dropout=0.0,
+            dropout=dropout,
         )

     with Model.define_operators({">>": chain, "|": concatenate}):
@@ -164,13 +170,10 @@ def MultiHashEmbed(columns, width, rows, use_subwords, pretrained_vectors, mix):
             embed_layer = norm
         else:
             if use_subwords and pretrained_vectors:
-                nr_columns = 5
                 concat_columns = glove | norm | prefix | suffix | shape
             elif use_subwords:
-                nr_columns = 4
                 concat_columns = norm | prefix | suffix | shape
             else:
-                nr_columns = 2
                 concat_columns = glove | norm

             embed_layer = uniqued(concat_columns >> mix, column=columns.index("ORTH"))
@@ -179,8 +182,8 @@ def MultiHashEmbed(columns, width, rows, use_subwords, pretrained_vectors, mix):


 @registry.architectures.register("spacy.CharacterEmbed.v1")
-def CharacterEmbed(columns, width, rows, nM, nC, features):
-    norm = HashEmbed(nO=width, nV=rows, column=columns.index("NORM"))
+def CharacterEmbed(columns, width, rows, nM, nC, features, dropout):
+    norm = HashEmbed(nO=width, nV=rows, column=columns.index("NORM"), dropout=dropout)
     chr_embed = _character_embed.CharacterEmbed(nM=nM, nC=nC)
     with Model.define_operators({">>": chain, "|": concatenate}):
         embed_layer = chr_embed | features >> with_array(norm)
@@ -238,16 +241,17 @@ def build_Tok2Vec_model(
     nC,
     conv_depth,
     bilstm_depth,
+    dropout,
 ) -> Model:
     if char_embed:
         subword_features = False
     cols = [ID, NORM, PREFIX, SUFFIX, SHAPE, ORTH]
     with Model.define_operators({">>": chain, "|": concatenate, "**": clone}):
-        norm = HashEmbed(nO=width, nV=embed_size, column=cols.index(NORM))
+        norm = HashEmbed(nO=width, nV=embed_size, column=cols.index(NORM), dropout=dropout)
         if subword_features:
-            prefix = HashEmbed(nO=width, nV=embed_size // 2, column=cols.index(PREFIX))
-            suffix = HashEmbed(nO=width, nV=embed_size // 2, column=cols.index(SUFFIX))
-            shape = HashEmbed(nO=width, nV=embed_size // 2, column=cols.index(SHAPE))
+            prefix = HashEmbed(nO=width, nV=embed_size // 2, column=cols.index(PREFIX), dropout=dropout)
+            suffix = HashEmbed(nO=width, nV=embed_size // 2, column=cols.index(SUFFIX), dropout=dropout)
+            shape = HashEmbed(nO=width, nV=embed_size // 2, column=cols.index(SHAPE), dropout=dropout)
         else:
             prefix, suffix, shape = (None, None, None)
         if pretrained_vectors is not None:
@@ -255,7 +259,7 @@ def build_Tok2Vec_model(
                 vectors=pretrained_vectors.data,
                 nO=width,
                 column=cols.index(ID),
-                dropout=0.0,
+                dropout=dropout,
             )

         if subword_features:
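The common thread in the hunks above is that the embedding tables now receive the dropout rate directly: `HashEmbed` gains a `dropout=dropout` argument and `StaticVectors` no longer hard-codes `dropout=0.0`. A minimal layer-level sketch of what the new argument does (the concrete sizes and the 0.1 rate are illustrative only, not taken from this diff):

from thinc.api import HashEmbed

# A NORM-column embedding table, as constructed in MultiHashEmbed and
# build_Tok2Vec_model above; dropout is now applied inside the layer itself.
norm = HashEmbed(nO=96, nV=2000, column=0, dropout=0.1)
norm.initialize()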
@@ -1,5 +1,5 @@
 from .pipes import Tagger, DependencyParser, EntityRecognizer, EntityLinker
-from .pipes import TextCategorizer, Tensorizer, Pipe, Sentencizer
+from .pipes import TextCategorizer, Pipe, Sentencizer
 from .pipes import SentenceRecognizer
 from .simple_ner import SimpleNER
 from .morphologizer import Morphologizer
@@ -14,7 +14,6 @@ __all__ = [
     "EntityRecognizer",
     "EntityLinker",
     "TextCategorizer",
-    "Tensorizer",
     "Tok2Vec",
     "Pipe",
     "Morphologizer",
@@ -63,16 +63,6 @@ def default_tagger():
     return util.load_config(loc, create_objects=True)["model"]


-def default_tensorizer_config():
-    loc = Path(__file__).parent / "tensorizer_defaults.cfg"
-    return util.load_config(loc, create_objects=False)
-
-
-def default_tensorizer():
-    loc = Path(__file__).parent / "tensorizer_defaults.cfg"
-    return util.load_config(loc, create_objects=True)["model"]
-
-
 def default_textcat_config():
     loc = Path(__file__).parent / "textcat_defaults.cfg"
     return util.load_config(loc, create_objects=False)
@@ -10,3 +10,4 @@ embed_size = 300
 window_size = 1
 maxout_pieces = 3
 subword_features = true
+dropout = null
@@ -11,3 +11,4 @@ window_size = 1
 maxout_pieces = 3
 nM = 64
 nC = 8
+dropout = null
@@ -13,3 +13,4 @@ embed_size = 2000
 window_size = 1
 maxout_pieces = 3
 subword_features = true
+dropout = null
@@ -13,3 +13,4 @@ embed_size = 2000
 window_size = 1
 maxout_pieces = 3
 subword_features = true
+dropout = null
@@ -10,3 +10,4 @@ embed_size = 2000
 window_size = 1
 maxout_pieces = 2
 subword_features = true
+dropout = null
@@ -10,3 +10,4 @@ embed_size = 7000
 window_size = 1
 maxout_pieces = 3
 subword_features = true
+dropout = null
@@ -10,3 +10,4 @@ embed_size = 2000
 window_size = 1
 maxout_pieces = 3
 subword_features = true
+dropout = null
@@ -1,4 +0,0 @@
-[model]
-@architectures = "spacy.Tensorizer.v1"
-input_size=96
-output_size=300
@@ -11,3 +11,4 @@ embed_size = 2000
 window_size = 1
 maxout_pieces = 3
 subword_features = true
+dropout = null
@@ -7,3 +7,4 @@ conv_depth = 2
 embed_size = 2000
 window_size = 1
 ngram_size = 1
+dropout = null
@@ -7,3 +7,4 @@ embed_size = 2000
 window_size = 1
 maxout_pieces = 3
 subword_features = true
+dropout = null
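Every model-defaults config above gains a `dropout` entry. In Thinc's config format, `null` parses to Python `None`, so the shipped defaults leave embedding dropout switched off unless a user overrides it. A small sketch of reading such a block (the `[model]` section shown here is assembled for the example from keys visible in these hunks, not copied from one specific defaults file):

from thinc.api import Config

cfg_str = """
[model]
embed_size = 2000
window_size = 1
maxout_pieces = 3
subword_features = true
dropout = null
"""

config = Config().from_str(cfg_str)
assert config["model"]["dropout"] is None      # "null" becomes Python None
assert config["model"]["subword_features"] is True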
@@ -44,8 +44,8 @@ class SentenceSegmenter(object):
 class SimilarityHook(Pipe):
     """
     Experimental: A pipeline component to install a hook for supervised
-    similarity into `Doc` objects. Requires a `Tensorizer` to pre-process
-    documents. The similarity model can be any object obeying the Thinc `Model`
+    similarity into `Doc` objects.
+    The similarity model can be any object obeying the Thinc `Model`
     interface. By default, the model concatenates the elementwise mean and
     elementwise max of the two tensors, and compares them using the
     Cauchy-like similarity function from Chen (2013):
@@ -82,7 +82,7 @@ class SimilarityHook(Pipe):
         sims, bp_sims = self.model.begin_update(doc1_doc2)

     def begin_training(self, _=tuple(), pipeline=None, sgd=None, **kwargs):
-        """Allocate model, using width from tensorizer in pipeline.
+        """Allocate model, using nO from the first model in the pipeline.

         gold_tuples (iterable): Gold-standard training data.
         pipeline (list): The pipeline the model is part of.
@@ -16,7 +16,7 @@ from ..morphology cimport Morphology
 from ..vocab cimport Vocab

 from .defaults import default_tagger, default_parser, default_ner, default_textcat
-from .defaults import default_nel, default_senter, default_tensorizer
+from .defaults import default_nel, default_senter
 from .functions import merge_subtokens
 from ..language import Language, component
 from ..syntax import nonproj
@@ -238,138 +238,6 @@ class Pipe(object):
         return self


-@component("tensorizer", assigns=["doc.tensor"], default_model=default_tensorizer)
-class Tensorizer(Pipe):
-    """Pre-train position-sensitive vectors for tokens."""
-
-    def __init__(self, vocab, model, **cfg):
-        """Construct a new statistical model. Weights are not allocated on
-        initialisation.
-
-        vocab (Vocab): A `Vocab` instance. The model must share the same
-            `Vocab` instance with the `Doc` objects it will process.
-        **cfg: Config parameters.
-        """
-        self.vocab = vocab
-        self.model = model
-        self.input_models = []
-        self.cfg = dict(cfg)
-
-    def __call__(self, example):
-        """Add context-sensitive vectors to a `Doc`, e.g. from a CNN or LSTM
-        model. Vectors are set to the `Doc.tensor` attribute.
-
-        docs (Doc or iterable): One or more documents to add vectors to.
-        RETURNS (dict or None): Intermediate computations.
-        """
-        doc = self._get_doc(example)
-        tokvecses = self.predict([doc])
-        self.set_annotations([doc], tokvecses)
-        if isinstance(example, Example):
-            example.doc = doc
-            return example
-        return doc
-
-    def pipe(self, stream, batch_size=128, n_threads=-1, as_example=False):
-        """Process `Doc` objects as a stream.
-
-        stream (iterator): A sequence of `Doc` or `Example` objects to process.
-        batch_size (int): Number of `Doc` or `Example` objects to group.
-        YIELDS (iterator): A sequence of `Doc` or `Example` objects, in order of input.
-        """
-        for examples in util.minibatch(stream, size=batch_size):
-            docs = [self._get_doc(ex) for ex in examples]
-            tensors = self.predict(docs)
-            self.set_annotations(docs, tensors)
-
-            if as_example:
-                for ex, doc in zip(examples, docs):
-                    ex.doc = doc
-                    yield ex
-            else:
-                yield from docs
-
-    def predict(self, docs):
-        """Return a single tensor for a batch of documents.
-
-        docs (iterable): A sequence of `Doc` objects.
-        RETURNS (object): Vector representations for each token in the docs.
-        """
-        inputs = self.model.ops.flatten([doc.tensor for doc in docs])
-        outputs = self.model(inputs)
-        return self.model.ops.unflatten(outputs, [len(d) for d in docs])
-
-    def set_annotations(self, docs, tensors):
-        """Set the tensor attribute for a batch of documents.
-
-        docs (iterable): A sequence of `Doc` objects.
-        tensors (object): Vector representation for each token in the docs.
-        """
-        for doc, tensor in zip(docs, tensors):
-            if tensor.shape[0] != len(doc):
-                raise ValueError(Errors.E076.format(rows=tensor.shape[0], words=len(doc)))
-            doc.tensor = tensor
-
-    def update(self, examples, state=None, drop=0.0, set_annotations=False, sgd=None, losses=None):
-        """Update the model.
-
-        docs (iterable): A batch of `Doc` objects.
-        golds (iterable): A batch of `GoldParse` objects.
-        drop (float): The dropout rate.
-        sgd (callable): An optimizer.
-        RETURNS (dict): Results from the update.
-        """
-        examples = Example.to_example_objects(examples)
-        inputs = []
-        bp_inputs = []
-        set_dropout_rate(self.model, drop)
-        for tok2vec in self.input_models:
-            set_dropout_rate(tok2vec, drop)
-            tensor, bp_tensor = tok2vec.begin_update([ex.doc for ex in examples])
-            inputs.append(tensor)
-            bp_inputs.append(bp_tensor)
-        inputs = self.model.ops.xp.hstack(inputs)
-        scores, bp_scores = self.model.begin_update(inputs)
-        loss, d_scores = self.get_loss(examples, scores)
-        d_inputs = bp_scores(d_scores, sgd=sgd)
-        d_inputs = self.model.ops.xp.split(d_inputs, len(self.input_models), axis=1)
-        for d_input, bp_input in zip(d_inputs, bp_inputs):
-            bp_input(d_input)
-        if sgd is not None:
-            for tok2vec in self.input_models:
-                tok2vec.finish_update(sgd)
-            self.model.finish_update(sgd)
-        if losses is not None:
-            losses.setdefault(self.name, 0.0)
-            losses[self.name] += loss
-        return loss
-
-    def get_loss(self, examples, prediction):
-        examples = Example.to_example_objects(examples)
-        ids = self.model.ops.flatten([ex.doc.to_array(ID).ravel() for ex in examples])
-        target = self.vocab.vectors.data[ids]
-        d_scores = (prediction - target) / prediction.shape[0]
-        loss = (d_scores ** 2).sum()
-        return loss, d_scores
-
-    def begin_training(self, get_examples=lambda: [], pipeline=None, sgd=None, **kwargs):
-        """Allocate models, pre-process training data and acquire an
-        optimizer.
-
-        get_examples (iterable): Gold-standard training data.
-        pipeline (list): The pipeline the model is part of.
-        """
-        if pipeline is not None:
-            for name, model in pipeline:
-                if model.has_ref("tok2vec"):
-                    self.input_models.append(model.get_ref("tok2vec"))
-        self.model.initialize()
-        link_vectors_to_models(self.vocab)
-        if sgd is None:
-            sgd = self.create_optimizer()
-        return sgd
-
-
 @component("tagger", assigns=["token.tag", "token.pos", "token.lemma"], default_model=default_tagger)
 class Tagger(Pipe):
     """Pipeline component for part-of-speech tagging.
@@ -1708,4 +1576,4 @@ def ner_factory(nlp, model, **cfg):
     warnings.warn(Warnings.W098.format(name="ner"))
     return EntityRecognizer.from_nlp(nlp, model, **cfg)

-__all__ = ["Tagger", "DependencyParser", "EntityRecognizer", "Tensorizer", "TextCategorizer", "EntityLinker", "Sentencizer", "SentenceRecognizer"]
+__all__ = ["Tagger", "DependencyParser", "EntityRecognizer", "TextCategorizer", "EntityLinker", "Sentencizer", "SentenceRecognizer"]
@@ -123,9 +123,9 @@ def test_overfitting_IO():
         {"@architectures": "spacy.TextCatBOW.v1", "exclusive_classes": True, "ngram_size": 4, "no_output_layer": False},
         {"@architectures": "spacy.TextCatBOW.v1", "exclusive_classes": False, "ngram_size": 3, "no_output_layer": True},
         {"@architectures": "spacy.TextCatBOW.v1", "exclusive_classes": True, "ngram_size": 2, "no_output_layer": True},
-        {"@architectures": "spacy.TextCat.v1", "exclusive_classes": False, "ngram_size": 1, "pretrained_vectors": False, "width": 64, "conv_depth": 2, "embed_size": 2000, "window_size": 2},
-        {"@architectures": "spacy.TextCat.v1", "exclusive_classes": True, "ngram_size": 5, "pretrained_vectors": False, "width": 128, "conv_depth": 2, "embed_size": 2000, "window_size": 1},
-        {"@architectures": "spacy.TextCat.v1", "exclusive_classes": True, "ngram_size": 2, "pretrained_vectors": False, "width": 32, "conv_depth": 3, "embed_size": 500, "window_size": 3},
+        {"@architectures": "spacy.TextCat.v1", "exclusive_classes": False, "ngram_size": 1, "pretrained_vectors": False, "width": 64, "conv_depth": 2, "embed_size": 2000, "window_size": 2, "dropout": None},
+        {"@architectures": "spacy.TextCat.v1", "exclusive_classes": True, "ngram_size": 5, "pretrained_vectors": False, "width": 128, "conv_depth": 2, "embed_size": 2000, "window_size": 1, "dropout": None},
+        {"@architectures": "spacy.TextCat.v1", "exclusive_classes": True, "ngram_size": 2, "pretrained_vectors": False, "width": 32, "conv_depth": 3, "embed_size": 500, "window_size": 3, "dropout": None},
         {"@architectures": "spacy.TextCatCNN.v1", "tok2vec": default_tok2vec(), "exclusive_classes": True},
         {"@architectures": "spacy.TextCatCNN.v1", "tok2vec": default_tok2vec(), "exclusive_classes": False},
     ],
@@ -24,6 +24,7 @@ window_size = 1
 embed_size = 2000
 maxout_pieces = 3
 subword_features = true
+dropout = null

 [nlp.pipeline.tagger]
 factory = "tagger"
@@ -53,6 +54,7 @@ embed_size = 5555
 window_size = 1
 maxout_pieces = 7
 subword_features = false
+dropout = null
 """


@@ -70,6 +72,7 @@ def my_parser():
         nC=8,
         conv_depth=2,
         bilstm_depth=0,
+        dropout=None,
     )
     parser = build_tb_parser_model(
         tok2vec=tok2vec, nr_feature_tokens=7, hidden_width=65, maxout_pieces=5
@@ -1,7 +1,7 @@
 import pytest
 from spacy.pipeline import Tagger, DependencyParser, EntityRecognizer
-from spacy.pipeline import Tensorizer, TextCategorizer, SentenceRecognizer
-from spacy.pipeline.defaults import default_parser, default_tensorizer, default_tagger
+from spacy.pipeline import TextCategorizer, SentenceRecognizer
+from spacy.pipeline.defaults import default_parser, default_tagger
 from spacy.pipeline.defaults import default_textcat, default_senter

 from ..util import make_tempdir
@@ -95,24 +95,6 @@ def test_serialize_tagger_roundtrip_disk(en_vocab, taggers):
         assert tagger1_d.to_bytes() == tagger2_d.to_bytes()


-def test_serialize_tensorizer_roundtrip_bytes(en_vocab):
-    tensorizer = Tensorizer(en_vocab, default_tensorizer())
-    tensorizer_b = tensorizer.to_bytes(exclude=["vocab"])
-    new_tensorizer = Tensorizer(en_vocab, default_tensorizer()).from_bytes(tensorizer_b)
-    assert new_tensorizer.to_bytes(exclude=["vocab"]) == tensorizer_b
-
-
-def test_serialize_tensorizer_roundtrip_disk(en_vocab):
-    tensorizer = Tensorizer(en_vocab, default_tensorizer())
-    with make_tempdir() as d:
-        file_path = d / "tensorizer"
-        tensorizer.to_disk(file_path)
-        tensorizer_d = Tensorizer(en_vocab, default_tensorizer()).from_disk(file_path)
-        assert tensorizer.to_bytes(exclude=["vocab"]) == tensorizer_d.to_bytes(
-            exclude=["vocab"]
-        )
-
-
 def test_serialize_textcat_empty(en_vocab):
     # See issue #1105
     textcat = TextCategorizer(
@@ -15,7 +15,7 @@ def test_empty_doc():
     vocab = Vocab()
     doc = Doc(vocab, words=[])
     # TODO: fix tok2vec arguments
-    tok2vec = build_Tok2Vec_model(width, embed_size)
+    tok2vec = build_Tok2Vec_model(width, embed_size, dropout=None)
     vectors, backprop = tok2vec.begin_update([doc])
     assert len(vectors) == 1
     assert vectors[0].shape == (0, width)
@@ -38,6 +38,7 @@ def test_tok2vec_batch_sizes(batch_size, width, embed_size):
         char_embed=False,
         nM=64,
         nC=8,
+        dropout=None,
     )
     tok2vec.initialize()
     vectors, backprop = tok2vec.begin_update(batch)
@@ -50,14 +51,14 @@ def test_tok2vec_batch_sizes(batch_size, width, embed_size):
 @pytest.mark.parametrize(
     "tok2vec_config",
     [
-        {"width": 8, "embed_size": 100, "char_embed": False, "nM": 64, "nC": 8, "pretrained_vectors": None, "window_size": 1, "conv_depth": 2, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": True},
-        {"width": 8, "embed_size": 100, "char_embed": True, "nM": 64, "nC": 8, "pretrained_vectors": None, "window_size": 1, "conv_depth": 2, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": True},
-        {"width": 8, "embed_size": 100, "char_embed": False, "nM": 64, "nC": 8, "pretrained_vectors": None, "window_size": 1, "conv_depth": 6, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": True},
-        {"width": 8, "embed_size": 100, "char_embed": False, "nM": 64, "nC": 8, "pretrained_vectors": None, "window_size": 1, "conv_depth": 6, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": True},
-        {"width": 8, "embed_size": 100, "char_embed": False, "nM": 64, "nC": 8, "pretrained_vectors": None, "window_size": 1, "conv_depth": 2, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": False},
-        {"width": 8, "embed_size": 100, "char_embed": False, "nM": 64, "nC": 8, "pretrained_vectors": None, "window_size": 3, "conv_depth": 2, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": False},
-        {"width": 8, "embed_size": 100, "char_embed": True, "nM": 81, "nC": 8, "pretrained_vectors": None, "window_size": 3, "conv_depth": 2, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": False},
-        {"width": 8, "embed_size": 100, "char_embed": True, "nM": 81, "nC": 9, "pretrained_vectors": None, "window_size": 3, "conv_depth": 2, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": False},
+        {"width": 8, "embed_size": 100, "char_embed": False, "nM": 64, "nC": 8, "pretrained_vectors": None, "window_size": 1, "conv_depth": 2, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": True, "dropout": None},
+        {"width": 8, "embed_size": 100, "char_embed": True, "nM": 64, "nC": 8, "pretrained_vectors": None, "window_size": 1, "conv_depth": 2, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": True, "dropout": None},
+        {"width": 8, "embed_size": 100, "char_embed": False, "nM": 64, "nC": 8, "pretrained_vectors": None, "window_size": 1, "conv_depth": 6, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": True, "dropout": None},
+        {"width": 8, "embed_size": 100, "char_embed": False, "nM": 64, "nC": 8, "pretrained_vectors": None, "window_size": 1, "conv_depth": 6, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": True, "dropout": None},
+        {"width": 8, "embed_size": 100, "char_embed": False, "nM": 64, "nC": 8, "pretrained_vectors": None, "window_size": 1, "conv_depth": 2, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": False, "dropout": None},
+        {"width": 8, "embed_size": 100, "char_embed": False, "nM": 64, "nC": 8, "pretrained_vectors": None, "window_size": 3, "conv_depth": 2, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": False, "dropout": None},
+        {"width": 8, "embed_size": 100, "char_embed": True, "nM": 81, "nC": 8, "pretrained_vectors": None, "window_size": 3, "conv_depth": 2, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": False, "dropout": None},
+        {"width": 8, "embed_size": 100, "char_embed": True, "nM": 81, "nC": 9, "pretrained_vectors": None, "window_size": 3, "conv_depth": 2, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": False, "dropout": None},
     ],
 )
 # fmt: on
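Because `dropout` is now a required argument of `build_Tok2Vec_model`, every parametrized config above must carry a `"dropout"` key. A hedged sketch of consuming one of those dicts outside the test suite (the `spacy.ml.models.tok2vec` import path is an assumption, not shown in this hunk):

from spacy.ml.models.tok2vec import build_Tok2Vec_model  # import path assumed
from spacy.tokens import Doc
from spacy.vocab import Vocab

# One of the parametrized configs from above, expanded by hand.
cfg = {"width": 8, "embed_size": 100, "char_embed": False, "nM": 64, "nC": 8,
       "pretrained_vectors": None, "window_size": 1, "conv_depth": 2,
       "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": True,
       "dropout": None}

docs = [Doc(Vocab(), words=["hello", "world"])]
tok2vec = build_Tok2Vec_model(**cfg)  # would raise a TypeError without the "dropout" key
tok2vec.initialize()
vectors, backprop = tok2vec.begin_update(docs)
assert vectors[0].shape == (2, cfg["width"])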