💫 Feature/improve pretraining (#2971)

* Improve spacy pretrain script

* Implement BERT-style 'masked language model' objective. Much better
results.

* Improve logging.

* Add length cap for documents, to avoid memory errors.

* Require thinc 7.0.0.dev1

* Add argument for using pretrained vectors

* Fix defaults

* Fix syntax error

* Tweak pretraining script

* Fix data limits in spacy.gold

* Fix pretrain script
Matthew Honnibal, 2018-11-28 18:04:58 +01:00, committed by GitHub
parent 0fdb25b958
commit 61e435610e
4 changed files with 123 additions and 22 deletions

requirements.txt

@@ -2,7 +2,7 @@ cython>=0.25
 numpy>=1.15.0
 cymem>=2.0.2,<2.1.0
 preshed>=2.0.1,<2.1.0
-thinc==7.0.0.dev0
+thinc==7.0.0.dev1
 blis>=0.2.2,<0.3.0
 murmurhash>=0.28.0,<1.1.0
 cytoolz>=0.9.0,<0.10.0

setup.py

@@ -200,7 +200,7 @@ def setup_package():
         "murmurhash>=0.28.0,<1.1.0",
         "cymem>=2.0.2,<2.1.0",
         "preshed>=2.0.1,<2.1.0",
-        "thinc==7.0.0.dev0",
+        "thinc==7.0.0.dev1",
         "blis>=0.2.2,<0.3.0",
         "plac<1.0.0,>=0.9.6",
         "ujson>=1.35",

spacy pretrain script

@@ -24,10 +24,12 @@ import sys
 from collections import Counter
 import spacy
-from spacy.attrs import ID
+from spacy.tokens import Doc
+from spacy.attrs import ID, HEAD
 from spacy.util import minibatch, minibatch_by_words, use_gpu, compounding, ensure_path
 from spacy._ml import Tok2Vec, flatten, chain, zero_init, create_default_optimizer
 from thinc.v2v import Affine
+from thinc.api import wrap

 def prefer_gpu():
@@ -47,13 +49,14 @@ def load_texts(path):
     '''
     path = ensure_path(path)
     with path.open('r', encoding='utf8') as file_:
-        texts = [json.loads(line)['text'] for line in file_]
+        texts = [json.loads(line) for line in file_]
     random.shuffle(texts)
     return texts

 def stream_texts():
     for line in sys.stdin:
-        yield json.loads(line)['text']
+        yield json.loads(line)

 def make_update(model, docs, optimizer, drop=0.):
@@ -65,11 +68,33 @@ def make_update(model, docs, optimizer, drop=0.):
     RETURNS loss: A float for the loss.
     """
     predictions, backprop = model.begin_update(docs, drop=drop)
-    loss, gradients = get_vectors_loss(model.ops, docs, predictions)
+    gradients = get_vectors_loss(model.ops, docs, predictions)
     backprop(gradients, sgd=optimizer)
+    # Don't want to return a cupy object here
+    # The gradients are modified in-place by the BERT MLM,
+    # so we get an accurate loss
+    loss = float((gradients**2).mean())
     return loss

+
+def make_docs(nlp, batch):
+    docs = []
+    for record in batch:
+        text = record["text"]
+        if "tokens" in record:
+            doc = Doc(nlp.vocab, words=record["tokens"])
+        else:
+            doc = nlp.make_doc(text)
+        if "heads" in record:
+            heads = record["heads"]
+            heads = numpy.asarray(heads, dtype="uint64")
+            heads = heads.reshape((len(doc), 1))
+            doc = doc.from_array([HEAD], heads)
+        if len(doc) >= 1 and len(doc) < 200:
+            docs.append(doc)
+    return docs
+
+
 def get_vectors_loss(ops, docs, prediction):
     """Compute a mean-squared error loss between the documents' vectors and
     the prediction.
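For context, load_texts and stream_texts now yield whole JSON records rather than just the 'text' field, and make_docs above reads the keys "text", "tokens" and "heads" from each record, skipping documents of 200 or more tokens (the length cap from the commit message). A minimal sketch of an input file in that shape; the example sentences and annotation values are invented, and only "text" is required:

import json

# Hypothetical records for the pretraining script. "tokens" supplies a
# pre-tokenized form and "heads" per-token head information; the values
# below are purely illustrative.
records = [
    {"text": "Berlin is a city in Germany."},
    {"text": "It rains.", "tokens": ["It", "rains", "."], "heads": [1, 1, 1]},
]
with open("texts.jsonl", "w", encoding="utf8") as file_:
    for record in records:
        file_.write(json.dumps(record) + "\n")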
@@ -84,10 +109,8 @@ def get_vectors_loss(ops, docs, prediction):
     # and look them up all at once. This prevents data copying.
     ids = ops.flatten([doc.to_array(ID).ravel() for doc in docs])
     target = docs[0].vocab.vectors.data[ids]
-    d_scores = (prediction - target) / prediction.shape[0]
-    # Don't want to return a cupy object here
-    loss = float((d_scores**2).sum())
-    return loss, d_scores
+    d_scores = prediction - target
+    return d_scores


 def create_pretraining_model(nlp, tok2vec):
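A note on the refactor above: get_vectors_loss now returns only the raw difference prediction - target, which is the gradient of the squared-error objective up to a constant factor, and make_update recomputes the scalar loss from the gradients after the MLM wrapper has zeroed the positions that were not masked. A small self-contained numpy sketch of that bookkeeping, with made-up values:

import numpy

# Stand-ins for the tok2vec predictions and the static target vectors.
prediction = numpy.array([[0.5, 1.0], [2.0, -1.0], [0.0, 3.0]], dtype="f")
target = numpy.array([[0.0, 1.0], [2.0, 0.0], [1.0, 3.0]], dtype="f")

d_scores = prediction - target   # what get_vectors_loss() now returns
# mlm_backward() keeps the gradient only for masked tokens; here row 1 was
# not masked, so its gradient is zeroed (the code above stores the
# complement and multiplies by 1 - mask, which has the same effect).
keep = numpy.array([[1.0], [0.0], [1.0]], dtype="f")
d_scores *= keep
# make_update() reports the mean squared gradient, so only masked
# positions contribute to the logged loss.
loss = float((d_scores ** 2).mean())
print(loss)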
@@ -107,15 +130,77 @@ def create_pretraining_model(nlp, tok2vec):
         tok2vec,
         output_layer
     )
+    model = masked_language_model(nlp.vocab, model)
     model.tok2vec = tok2vec
     model.output_layer = output_layer
     model.begin_training([nlp.make_doc('Give it a doc to infer shapes')])
     return model

+
+def masked_language_model(vocab, model, mask_prob=0.15):
+    '''Convert a model into a BERT-style masked language model'''
+    vocab_words = [lex.text for lex in vocab if lex.prob != 0.0]
+    vocab_probs = [lex.prob for lex in vocab if lex.prob != 0.0]
+    vocab_words = vocab_words[:10000]
+    vocab_probs = vocab_probs[:10000]
+    vocab_probs = numpy.exp(numpy.array(vocab_probs, dtype='f'))
+    vocab_probs /= vocab_probs.sum()
+
+    def mlm_forward(docs, drop=0.):
+        mask, docs = apply_mask(docs, vocab_words, vocab_probs,
+                                mask_prob=mask_prob)
+        mask = model.ops.asarray(mask).reshape((mask.shape[0], 1))
+        output, backprop = model.begin_update(docs, drop=drop)
+
+        def mlm_backward(d_output, sgd=None):
+            d_output *= 1-mask
+            return backprop(d_output, sgd=sgd)
+
+        return output, mlm_backward
+
+    return wrap(mlm_forward, model)
+
+
+def apply_mask(docs, vocab_texts, vocab_probs, mask_prob=0.15):
+    N = sum(len(doc) for doc in docs)
+    mask = numpy.random.uniform(0., 1.0, (N,))
+    mask = mask >= mask_prob
+    i = 0
+    masked_docs = []
+    for doc in docs:
+        words = []
+        for token in doc:
+            if not mask[i]:
+                word = replace_word(token.text, vocab_texts, vocab_probs)
+            else:
+                word = token.text
+            words.append(word)
+            i += 1
+        spaces = [bool(w.whitespace_) for w in doc]
+        # NB: If you change this implementation to instead modify
+        # the docs in place, take care that the IDs reflect the original
+        # words. Currently we use the original docs to make the vectors
+        # for the target, so we don't lose the original tokens. But if
+        # you modified the docs in place here, you would.
+        masked_docs.append(Doc(doc.vocab, words=words, spaces=spaces))
+    return mask, masked_docs
+
+
+def replace_word(word, vocab_texts, vocab_probs, mask='[MASK]'):
+    roll = random.random()
+    if roll < 0.8:
+        return mask
+    elif roll < 0.9:
+        index = numpy.random.choice(len(vocab_texts), p=vocab_probs)
+        return vocab_texts[index]
+    else:
+        return word
+
+
 class ProgressTracker(object):
     def __init__(self, frequency=100000):
-        self.loss = 0.
+        self.loss = 0.0
+        self.prev_loss = 0.0
         self.nr_word = 0
         self.words_per_epoch = Counter()
         self.frequency = frequency
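Roughly speaking, with mask_prob=0.15 and the 80/10/10 split in replace_word above, about 12% of tokens are replaced by '[MASK]', about 1.5% by a random frequent word, and about 1.5% keep their surface form; all three groups retain their gradient in mlm_backward, while the untouched 85% have theirs zeroed. A standalone simulation of that breakdown (it re-rolls the same dice as apply_mask and replace_word rather than calling them):

import random
import numpy

N = 100000
mask_prob = 0.15
counts = {"[MASK]": 0, "random word": 0, "kept, with gradient": 0, "untouched": 0}
for _ in range(N):
    if numpy.random.uniform(0.0, 1.0) >= mask_prob:
        counts["untouched"] += 1    # mask[i] is True: token left alone, gradient zeroed
        continue
    roll = random.random()          # same branching as replace_word()
    if roll < 0.8:
        counts["[MASK]"] += 1
    elif roll < 0.9:
        counts["random word"] += 1
    else:
        counts["kept, with gradient"] += 1
for name, count in counts.items():
    print(name, round(count / N, 4))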
@@ -132,7 +217,15 @@ class ProgressTracker(object):
             wps = words_since_update / (time.time() - self.last_time)
             self.last_update = self.nr_word
             self.last_time = time.time()
-            status = (epoch, self.nr_word, '%.5f' % self.loss, int(wps))
+            loss_per_word = self.loss - self.prev_loss
+            status = (
+                epoch,
+                self.nr_word,
+                "%.5f" % self.loss,
+                "%.4f" % loss_per_word,
+                int(wps),
+            )
+            self.prev_loss = float(self.loss)
             return status
         else:
             return None
@@ -145,12 +238,13 @@ class ProgressTracker(object):
     width=("Width of CNN layers", "option", "cw", int),
     depth=("Depth of CNN layers", "option", "cd", int),
     embed_rows=("Embedding rows", "option", "er", int),
+    use_vectors=("Whether to use the static vectors as input features", "flag", "uv"),
     dropout=("Dropout", "option", "d", float),
     seed=("Seed for random number generators", "option", "s", float),
     nr_iter=("Number of iterations to pretrain", "option", "i", int),
 )
 def pretrain(texts_loc, vectors_model, output_dir, width=128, depth=4,
-             embed_rows=1000, dropout=0.2, nr_iter=10, seed=0):
+             embed_rows=5000, use_vectors=False, dropout=0.2, nr_iter=100, seed=0):
     """
     Pre-train the 'token-to-vector' (tok2vec) layer of pipeline components,
     using an approximate language-modelling objective. Specifically, we load
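For reference, a hedged sketch of how the updated signature might be called directly; the input path, model name and output directory are placeholders, and in practice the arguments come through the plac command-line wrapper above:

# Assumes pretrain() from the patched script is in scope and that the
# script converts output_dir to a Path internally.
pretrain(
    "texts.jsonl",          # texts_loc: JSONL file, or '-' to read from stdin
    "en_vectors_web_lg",    # vectors_model: needs vectors if use_vectors is set
    "pretrain_output",      # output_dir
    width=128,
    depth=4,
    embed_rows=5000,        # new default
    use_vectors=True,       # new flag: feed the static vectors in as input features
    nr_iter=100,            # new default
)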
@@ -175,11 +269,13 @@ def pretrain(texts_loc, vectors_model, output_dir, width=128, depth=4,
     with (output_dir / 'config.json').open('w') as file_:
         file_.write(json.dumps(config))
     has_gpu = prefer_gpu()
+    print("Use GPU?", has_gpu)
     nlp = spacy.load(vectors_model)
+    pretrained_vectors = None if not use_vectors else nlp.vocab.vectors.name
     model = create_pretraining_model(nlp,
                 Tok2Vec(width, embed_rows,
                     conv_depth=depth,
-                    pretrained_vectors=nlp.vocab.vectors.name,
+                    pretrained_vectors=pretrained_vectors,
                     bilstm_depth=0,  # Requires PyTorch. Experimental.
                     cnn_maxout_pieces=2,  # You can try setting this higher
                     subword_features=True))  # Set to False for character models, e.g. Chinese
@@ -188,19 +284,19 @@ def pretrain(texts_loc, vectors_model, output_dir, width=128, depth=4,
     print('Epoch', '#Words', 'Loss', 'w/s')
     texts = stream_texts() if texts_loc == '-' else load_texts(texts_loc)
     for epoch in range(nr_iter):
-        for batch in minibatch(texts, size=64):
-            docs = [nlp.make_doc(text) for text in batch]
+        for batch in minibatch(texts, size=256):
+            docs = make_docs(nlp, batch)
             loss = make_update(model, docs, optimizer, drop=dropout)
             progress = tracker.update(epoch, loss, docs)
             if progress:
                 print(*progress)
-                if texts_loc == '-' and tracker.words_per_epoch[epoch] >= 10**6:
+                if texts_loc == '-' and tracker.words_per_epoch[epoch] >= 10**7:
                     break
         with model.use_params(optimizer.averages):
             with (output_dir / ('model%d.bin' % epoch)).open('wb') as file_:
                 file_.write(model.tok2vec.to_bytes())
         with (output_dir / 'log.jsonl').open('a') as file_:
             file_.write(json.dumps({'nr_word': tracker.nr_word,
-                'loss': tracker.loss, 'epoch': epoch}))
+                'loss': tracker.loss, 'epoch': epoch}) + '\n')
     if texts_loc != '-':
         texts = load_texts(texts_loc)
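Because each progress record is now written with a trailing newline, log.jsonl can be read back line by line; a small sketch, assuming the placeholder output directory from the earlier example:

import json
from pathlib import Path

# Each line holds a JSON object with 'nr_word', 'loss' and 'epoch'.
log_path = Path("pretrain_output") / "log.jsonl"
with log_path.open("r", encoding="utf8") as file_:
    for line in file_:
        record = json.loads(line)
        print(record["epoch"], record["nr_word"], record["loss"])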

spacy/gold.pyx

@@ -25,6 +25,7 @@ from .compat import json_dumps
 from libc.stdio cimport FILE, fopen, fclose, fread, fwrite, feof, fseek

 def tags_to_entities(tags):
     entities = []
     start = None
@@ -110,19 +111,23 @@ class GoldCorpus(object):
             # Write temp directory with one doc per file, so we can shuffle
             # and stream
             self.tmp_dir = Path(tempfile.mkdtemp())
-            self.write_msgpack(self.tmp_dir / 'train', train)
-            self.write_msgpack(self.tmp_dir / 'dev', dev)
+            self.write_msgpack(self.tmp_dir / 'train', train, limit=self.limit)
+            self.write_msgpack(self.tmp_dir / 'dev', dev, limit=self.limit)

     def __del__(self):
         shutil.rmtree(self.tmp_dir)

     @staticmethod
-    def write_msgpack(directory, doc_tuples):
+    def write_msgpack(directory, doc_tuples, limit=0):
         if not directory.exists():
             directory.mkdir()
+        n = 0
         for i, doc_tuple in enumerate(doc_tuples):
             with open(directory / '{}.msg'.format(i), 'wb') as file_:
                 msgpack.dump([doc_tuple], file_, use_bin_type=True)
+            n += len(doc_tuple[1])
+            if limit and n >= limit:
+                break

     @staticmethod
     def walk_corpus(path):
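The limit in write_msgpack above is applied to a running count of len(doc_tuple[1]) (the per-document annotation entries), not to the number of documents, and the document that crosses the limit is still written. A toy sketch of that cut-off arithmetic, with made-up tuples standing in for the GoldCorpus data:

# Each tuple stands in for (raw_text, annotations); only the length of the
# second element matters for the limit.
doc_tuples = [
    ("doc0", ["para"] * 60),
    ("doc1", ["para"] * 50),
    ("doc2", ["para"] * 40),
]
limit = 100
n = 0
written = []
for i, doc_tuple in enumerate(doc_tuples):
    written.append(i)                 # the file for this doc gets written
    n += len(doc_tuple[1])
    if limit and n >= limit:
        break
print(written)                        # [0, 1]: doc2 is never written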
@@ -350,7 +355,7 @@ def _json_iterate(loc):
             py_str = py_raw[start : i+1].decode('utf8')
             try:
                 yield json.loads(py_str)
-            except:
+            except Exception:
                 print(py_str)
                 raise
             start = -1