Make pretrain script work with stream from stdin

Matthew Honnibal 2018-11-15 22:44:07 +00:00
parent 8fdb9bc278
commit 3e7b214e57
2 changed files with 29 additions and 12 deletions
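
With this change, the pretraining script can consume its corpus as a stream: passing `-` as the texts location reads newline-delimited JSON from stdin instead of loading a file into memory. A minimal sketch of the input the script expects, assuming only what the diff below shows (one JSON object per line with a 'text' field); the pipe invocation in the comment is an assumption, not part of this commit:

import json
import sys

# Emit the JSONL stream the pretrain script reads on stdin.
for text in ["The first raw text.", "The second raw text."]:
    sys.stdout.write(json.dumps({'text': text}) + '\n')

# Assumed usage, e.g.:
#   python make_jsonl.py | python pretrain.py - <vectors_model> <output_dir>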


@@ -20,10 +20,11 @@ import numpy
 import time
 import ujson as json
 from pathlib import Path
+import sys

 import spacy
 from spacy.attrs import ID
-from spacy.util import minibatch, use_gpu, compounding, ensure_path
+from spacy.util import minibatch_by_words, use_gpu, compounding, ensure_path
 from spacy._ml import Tok2Vec, flatten, chain, zero_init, create_default_optimizer
 from thinc.v2v import Affine
@@ -45,9 +46,13 @@ def load_texts(path):
     '''
     path = ensure_path(path)
     with path.open('r', encoding='utf8') as file_:
-        for line in file_:
-            data = json.loads(line)
-            yield data['text']
+        texts = [json.loads(line)['text'] for line in file_]
+    random.shuffle(texts)
+    return texts
+
+
+def stream_texts():
+    for line in sys.stdin:
+        yield json.loads(line)['text']


 def make_update(model, docs, optimizer, drop=0.):
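
load_texts() now materializes and shuffles the whole file up front, while the new stream_texts() stays lazy and reads sys.stdin one line at a time. Because it touches sys.stdin directly, it can be exercised against an in-memory buffer; a testing sketch (not part of the commit, function body copied from the hunk above):

import io
import json
import sys

def stream_texts():
    for line in sys.stdin:
        yield json.loads(line)['text']

old_stdin = sys.stdin
sys.stdin = io.StringIO('{"text": "hello"}\n{"text": "world"}\n')
try:
    assert list(stream_texts()) == ['hello', 'world']
finally:
    sys.stdin = old_stdin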
@@ -102,16 +107,19 @@ def create_pretraining_model(nlp, tok2vec):
 class ProgressTracker(object):
-    def __init__(self, frequency=10000):
+    def __init__(self, frequency=100000):
         self.loss = 0.
         self.nr_word = 0
+        self.words_per_epoch = Counter()
         self.frequency = frequency
         self.last_time = time.time()
         self.last_update = 0

     def update(self, epoch, loss, docs):
         self.loss += loss
-        self.nr_word += sum(len(doc) for doc in docs)
+        words_in_batch = sum(len(doc) for doc in docs)
+        self.words_per_epoch[epoch] += words_in_batch
+        self.nr_word += words_in_batch
         words_since_update = self.nr_word - self.last_update
         if words_since_update >= self.frequency:
             wps = words_since_update / (time.time() - self.last_time)
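
The tracker now also bins word counts per epoch in a Counter (the diff adds no import, so Counter is presumably already imported in this file). The streaming code path below uses these per-epoch totals to decide when an "epoch" of stdin data is done. A stripped-down sketch of the bookkeeping:

from collections import Counter

words_per_epoch = Counter()

def update(epoch, docs):
    # Mirrors the counting added to ProgressTracker.update(); 'docs' here
    # are toy token lists, so len(doc) is a word count.
    words_in_batch = sum(len(doc) for doc in docs)
    words_per_epoch[epoch] += words_in_batch

update(0, ['one small doc'.split(), 'and another one here'.split()])
update(1, ['next epoch'.split()])
assert words_per_epoch[0] == 7 and words_per_epoch[1] == 2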
@@ -170,19 +178,22 @@ def pretrain(texts_loc, vectors_model, output_dir, width=128, depth=4,
     model = create_pretraining_model(nlp, tok2vec)
     optimizer = create_default_optimizer(model.ops)
     tracker = ProgressTracker()
-    texts = list(load_texts(texts_loc))
     print('Epoch', '#Words', 'Loss', 'w/s')
+    texts = stream_texts() if texts_loc == '-' else load_texts(texts_loc)
     for epoch in range(nr_iter):
-        random.shuffle(texts)
-        for batch in minibatch(texts):
+        for batch in minibatch_by_words(texts, tuples=False, size=50000):
             docs = [nlp.make_doc(text) for text in batch]
             loss = make_update(model, docs, optimizer, drop=dropout)
             progress = tracker.update(epoch, loss, docs)
             if progress:
                 print(*progress)
+                if texts_loc == '-' and tracker.words_per_epoch[epoch] >= 10**7:
+                    break
         with model.use_params(optimizer.averages):
             with (output_dir / ('model%d.bin' % epoch)).open('wb') as file_:
                 file_.write(tok2vec.to_bytes())
             with (output_dir / 'log.jsonl').open('a') as file_:
                 file_.write(json.dumps({'nr_word': tracker.nr_word,
                                         'loss': tracker.loss, 'epoch': epoch}))
+        if texts_loc != '-':
+            texts = load_texts(texts_loc)
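
For a file corpus, the loop reloads the texts (reshuffled inside load_texts) after every epoch; stdin can only be read once, so a streamed "epoch" ends after 10**7 words instead. The control flow, distilled with stub stand-ins (an illustration under stated assumptions, not the real objects):

def load_stub():
    # Stands in for load_texts(): re-readable, reshuffled each call.
    return ['some text %d' % i for i in range(3)]

def stream_stub():
    # Stands in for stream_texts(): a single, endless pass.
    i = 0
    while True:
        yield 'streamed text %d' % i
        i += 1

EPOCH_WORD_CAP = 10 ** 7      # the per-epoch budget used for stdin
texts_loc = 'corpus.jsonl'    # or '-' to exercise the streaming branch
texts = stream_stub() if texts_loc == '-' else load_stub()
for epoch in range(2):
    words = 0
    for text in texts:
        words += len(text.split())
        if texts_loc == '-' and words >= EPOCH_WORD_CAP:
            break             # a stream cannot be re-read: cap the epoch
    if texts_loc != '-':
        texts = load_stub()   # file corpus: reload for the next epoch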


@@ -465,7 +465,7 @@ def decaying(start, stop, decay):
         nr_upd += 1


-def minibatch_by_words(items, size, count_words=len):
+def minibatch_by_words(items, size, tuples=True, count_words=len):
     '''Create minibatches of a given number of words.'''
     if isinstance(size, int):
         size_ = itertools.repeat(size)
@@ -477,13 +477,19 @@ def minibatch_by_words(items, size, count_words=len):
         batch = []
         while batch_size >= 0:
             try:
-                doc, gold = next(items)
+                if tuples:
+                    doc, gold = next(items)
+                else:
+                    doc = next(items)
             except StopIteration:
                 if batch:
                     yield batch
                 return
             batch_size -= count_words(doc)
-            batch.append((doc, gold))
+            if tuples:
+                batch.append((doc, gold))
+            else:
+                batch.append(doc)
         if batch:
             yield batch
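
With tuples=False, minibatch_by_words batches bare items instead of (doc, gold) pairs, which is what the pretraining loop above needs for raw texts. A usage sketch against the post-commit signature (count_words defaults to len, so a token counter is supplied here; note a batch may overshoot size by the item that crossed the budget):

from spacy.util import minibatch_by_words  # post-commit signature assumed

texts = ['a b c', 'd e', 'f g h i', 'j']
batches = list(minibatch_by_words(iter(texts), size=5, tuples=False,
                                  count_words=lambda t: len(t.split())))
print(batches)  # [['a b c', 'd e', 'f g h i'], ['j']]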