mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 17:36:30 +03:00
Make pretrain script work with stream from stdin
This commit is contained in:
parent
8fdb9bc278
commit
3e7b214e57
|
@ -20,10 +20,11 @@ import numpy
|
|||
import time
|
||||
import ujson as json
|
||||
from pathlib import Path
|
||||
import sys
|
||||
|
||||
import spacy
|
||||
from spacy.attrs import ID
|
||||
from spacy.util import minibatch, use_gpu, compounding, ensure_path
|
||||
from spacy.util import minibatch_by_words, use_gpu, compounding, ensure_path
|
||||
from spacy._ml import Tok2Vec, flatten, chain, zero_init, create_default_optimizer
|
||||
from thinc.v2v import Affine
|
||||
|
||||
|
@ -45,9 +46,13 @@ def load_texts(path):
|
|||
'''
|
||||
path = ensure_path(path)
|
||||
with path.open('r', encoding='utf8') as file_:
|
||||
for line in file_:
|
||||
data = json.loads(line)
|
||||
yield data['text']
|
||||
texts = [json.loads(line)['text'] for line in file_]
|
||||
random.shuffle(texts)
|
||||
return texts
|
||||
|
||||
def stream_texts():
|
||||
for line in sys.stdin:
|
||||
yield json.loads(line)['text']
|
||||
|
||||
|
||||
def make_update(model, docs, optimizer, drop=0.):
|
||||
|
@ -102,16 +107,19 @@ def create_pretraining_model(nlp, tok2vec):
|
|||
|
||||
|
||||
class ProgressTracker(object):
|
||||
def __init__(self, frequency=10000):
|
||||
def __init__(self, frequency=100000):
|
||||
self.loss = 0.
|
||||
self.nr_word = 0
|
||||
self.words_per_epoch = Counter()
|
||||
self.frequency = frequency
|
||||
self.last_time = time.time()
|
||||
self.last_update = 0
|
||||
|
||||
def update(self, epoch, loss, docs):
|
||||
self.loss += loss
|
||||
self.nr_word += sum(len(doc) for doc in docs)
|
||||
words_in_batch = sum(len(doc) for doc in docs)
|
||||
self.words_per_epoch[epoch] += words_in_batch
|
||||
self.nr_word += words_in_batch
|
||||
words_since_update = self.nr_word - self.last_update
|
||||
if words_since_update >= self.frequency:
|
||||
wps = words_since_update / (time.time() - self.last_time)
|
||||
|
@ -170,19 +178,22 @@ def pretrain(texts_loc, vectors_model, output_dir, width=128, depth=4,
|
|||
model = create_pretraining_model(nlp, tok2vec)
|
||||
optimizer = create_default_optimizer(model.ops)
|
||||
tracker = ProgressTracker()
|
||||
texts = list(load_texts(texts_loc))
|
||||
print('Epoch', '#Words', 'Loss', 'w/s')
|
||||
texts = stream_texts() if text_loc == '-' else load_texts(texts_loc)
|
||||
for epoch in range(nr_iter):
|
||||
random.shuffle(texts)
|
||||
for batch in minibatch(texts):
|
||||
for batch in minibatch_by_words(texts, tuples=False, size=50000):
|
||||
docs = [nlp.make_doc(text) for text in batch]
|
||||
loss = make_update(model, docs, optimizer, drop=dropout)
|
||||
progress = tracker.update(epoch, loss, docs)
|
||||
if progress:
|
||||
print(*progress)
|
||||
if texts_loc == '-' and progress.words_per_epoch[epoch] >= 10**7:
|
||||
break
|
||||
with model.use_params(optimizer.averages):
|
||||
with (output_dir / ('model%d.bin' % epoch)).open('wb') as file_:
|
||||
file_.write(tok2vec.to_bytes())
|
||||
with (output_dir / 'log.jsonl').open('a') as file_:
|
||||
file_.write(json.dumps({'nr_word': tracker.nr_word,
|
||||
'loss': tracker.loss, 'epoch': epoch}))
|
||||
if texts_loc != '-':
|
||||
texts = load_texts(texts_loc)
|
||||
|
|
|
@ -465,7 +465,7 @@ def decaying(start, stop, decay):
|
|||
nr_upd += 1
|
||||
|
||||
|
||||
def minibatch_by_words(items, size, count_words=len):
|
||||
def minibatch_by_words(items, size, tuples=True, count_words=len):
|
||||
'''Create minibatches of a given number of words.'''
|
||||
if isinstance(size, int):
|
||||
size_ = itertools.repeat(size)
|
||||
|
@ -477,13 +477,19 @@ def minibatch_by_words(items, size, count_words=len):
|
|||
batch = []
|
||||
while batch_size >= 0:
|
||||
try:
|
||||
doc, gold = next(items)
|
||||
if tuples:
|
||||
doc, gold = next(items)
|
||||
else:
|
||||
doc = next(items)
|
||||
except StopIteration:
|
||||
if batch:
|
||||
yield batch
|
||||
return
|
||||
batch_size -= count_words(doc)
|
||||
batch.append((doc, gold))
|
||||
if tuples:
|
||||
batch.append((doc, gold))
|
||||
else:
|
||||
batch.append(doc)
|
||||
if batch:
|
||||
yield batch
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user