Make pretrain script work with stream from stdin

This commit is contained in:
Matthew Honnibal 2018-11-15 22:44:07 +00:00
parent 8fdb9bc278
commit 3e7b214e57
2 changed files with 29 additions and 12 deletions

View File

@ -20,10 +20,11 @@ import numpy
import time
import ujson as json
from pathlib import Path
import sys
import spacy
from spacy.attrs import ID
from spacy.util import minibatch, use_gpu, compounding, ensure_path
from spacy.util import minibatch_by_words, use_gpu, compounding, ensure_path
from spacy._ml import Tok2Vec, flatten, chain, zero_init, create_default_optimizer
from thinc.v2v import Affine
@ -45,9 +46,13 @@ def load_texts(path):
'''
path = ensure_path(path)
with path.open('r', encoding='utf8') as file_:
for line in file_:
data = json.loads(line)
yield data['text']
texts = [json.loads(line)['text'] for line in file_]
random.shuffle(texts)
return texts
def stream_texts():
for line in sys.stdin:
yield json.loads(line)['text']
def make_update(model, docs, optimizer, drop=0.):
@ -102,16 +107,19 @@ def create_pretraining_model(nlp, tok2vec):
class ProgressTracker(object):
def __init__(self, frequency=10000):
def __init__(self, frequency=100000):
self.loss = 0.
self.nr_word = 0
self.words_per_epoch = Counter()
self.frequency = frequency
self.last_time = time.time()
self.last_update = 0
def update(self, epoch, loss, docs):
self.loss += loss
self.nr_word += sum(len(doc) for doc in docs)
words_in_batch = sum(len(doc) for doc in docs)
self.words_per_epoch[epoch] += words_in_batch
self.nr_word += words_in_batch
words_since_update = self.nr_word - self.last_update
if words_since_update >= self.frequency:
wps = words_since_update / (time.time() - self.last_time)
@ -170,19 +178,22 @@ def pretrain(texts_loc, vectors_model, output_dir, width=128, depth=4,
model = create_pretraining_model(nlp, tok2vec)
optimizer = create_default_optimizer(model.ops)
tracker = ProgressTracker()
texts = list(load_texts(texts_loc))
print('Epoch', '#Words', 'Loss', 'w/s')
texts = stream_texts() if text_loc == '-' else load_texts(texts_loc)
for epoch in range(nr_iter):
random.shuffle(texts)
for batch in minibatch(texts):
for batch in minibatch_by_words(texts, tuples=False, size=50000):
docs = [nlp.make_doc(text) for text in batch]
loss = make_update(model, docs, optimizer, drop=dropout)
progress = tracker.update(epoch, loss, docs)
if progress:
print(*progress)
if texts_loc == '-' and progress.words_per_epoch[epoch] >= 10**7:
break
with model.use_params(optimizer.averages):
with (output_dir / ('model%d.bin' % epoch)).open('wb') as file_:
file_.write(tok2vec.to_bytes())
with (output_dir / 'log.jsonl').open('a') as file_:
file_.write(json.dumps({'nr_word': tracker.nr_word,
'loss': tracker.loss, 'epoch': epoch}))
if texts_loc != '-':
texts = load_texts(texts_loc)

View File

@ -465,7 +465,7 @@ def decaying(start, stop, decay):
nr_upd += 1
def minibatch_by_words(items, size, count_words=len):
def minibatch_by_words(items, size, tuples=True, count_words=len):
'''Create minibatches of a given number of words.'''
if isinstance(size, int):
size_ = itertools.repeat(size)
@ -477,13 +477,19 @@ def minibatch_by_words(items, size, count_words=len):
batch = []
while batch_size >= 0:
try:
doc, gold = next(items)
if tuples:
doc, gold = next(items)
else:
doc = next(items)
except StopIteration:
if batch:
yield batch
return
batch_size -= count_words(doc)
batch.append((doc, gold))
if tuples:
batch.append((doc, gold))
else:
batch.append(doc)
if batch:
yield batch