Mirror of https://github.com/explosion/spaCy.git
Make pretrain script work with stream from stdin

commit 3e7b214e57
parent 8fdb9bc278
@@ -20,10 +20,11 @@ import numpy
 import time
 import ujson as json
 from pathlib import Path
+import sys

 import spacy
 from spacy.attrs import ID
-from spacy.util import minibatch, use_gpu, compounding, ensure_path
+from spacy.util import minibatch_by_words, use_gpu, compounding, ensure_path
 from spacy._ml import Tok2Vec, flatten, chain, zero_init, create_default_optimizer
 from thinc.v2v import Affine
@@ -45,9 +46,13 @@ def load_texts(path):
     '''
     path = ensure_path(path)
     with path.open('r', encoding='utf8') as file_:
-        for line in file_:
-            data = json.loads(line)
-            yield data['text']
+        texts = [json.loads(line)['text'] for line in file_]
+    random.shuffle(texts)
+    return texts
+
+def stream_texts():
+    for line in sys.stdin:
+        yield json.loads(line)['text']


 def make_update(model, docs, optimizer, drop=0.):
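As the hunk above shows, load_texts() now reads the whole JSONL file eagerly, shuffles it and returns a list, while the new stream_texts() lazily yields the 'text' field of each JSON line arriving on stdin. A minimal sketch of that parsing, assuming the same one-JSON-object-per-line input shape; it reads from an io.StringIO stand-in instead of sys.stdin so it can be run as-is, and uses the standard-library json rather than ujson:

import io
import json

# The expected input is JSON Lines: one object per line, each with a 'text' key.
fake_stdin = io.StringIO(
    '{"text": "This is the first document."}\n'
    '{"text": "And this is the second one."}\n'
)

def stream_texts(stream):
    # Same parsing as the script's stream_texts(), but from any file-like object.
    for line in stream:
        yield json.loads(line)['text']

for text in stream_texts(fake_stdin):
    print(text)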
@@ -102,16 +107,19 @@ def create_pretraining_model(nlp, tok2vec):


 class ProgressTracker(object):
-    def __init__(self, frequency=10000):
+    def __init__(self, frequency=100000):
         self.loss = 0.
         self.nr_word = 0
+        self.words_per_epoch = Counter()
         self.frequency = frequency
         self.last_time = time.time()
         self.last_update = 0

     def update(self, epoch, loss, docs):
         self.loss += loss
-        self.nr_word += sum(len(doc) for doc in docs)
+        words_in_batch = sum(len(doc) for doc in docs)
+        self.words_per_epoch[epoch] += words_in_batch
+        self.nr_word += words_in_batch
         words_since_update = self.nr_word - self.last_update
         if words_since_update >= self.frequency:
             wps = words_since_update / (time.time() - self.last_time)
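The new words_per_epoch Counter gives the rest of the script a way to ask how many words a given epoch has already processed, which is what the stdin mode uses to cap an epoch at a fixed word budget. A tiny standard-library sketch of that bookkeeping, with toy token lists standing in for Doc objects:

from collections import Counter

# Per-epoch word totals, keyed by epoch number, as in ProgressTracker.update().
words_per_epoch = Counter()

def record(epoch, docs):
    words_per_epoch[epoch] += sum(len(doc) for doc in docs)

record(0, [['a', 'b', 'c'], ['d', 'e']])
record(0, [['f']])
print(words_per_epoch[0])            # 6
print(words_per_epoch[0] >= 10**7)   # the cut-off check used for stdin input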
@@ -170,19 +178,22 @@ def pretrain(texts_loc, vectors_model, output_dir, width=128, depth=4,
     model = create_pretraining_model(nlp, tok2vec)
     optimizer = create_default_optimizer(model.ops)
     tracker = ProgressTracker()
-    texts = list(load_texts(texts_loc))
     print('Epoch', '#Words', 'Loss', 'w/s')
+    texts = stream_texts() if text_loc == '-' else load_texts(texts_loc)
     for epoch in range(nr_iter):
-        random.shuffle(texts)
-        for batch in minibatch(texts):
+        for batch in minibatch_by_words(texts, tuples=False, size=50000):
             docs = [nlp.make_doc(text) for text in batch]
             loss = make_update(model, docs, optimizer, drop=dropout)
             progress = tracker.update(epoch, loss, docs)
             if progress:
                 print(*progress)
+                if texts_loc == '-' and progress.words_per_epoch[epoch] >= 10**7:
+                    break
         with model.use_params(optimizer.averages):
             with (output_dir / ('model%d.bin' % epoch)).open('wb') as file_:
                 file_.write(tok2vec.to_bytes())
             with (output_dir / 'log.jsonl').open('a') as file_:
                 file_.write(json.dumps({'nr_word': tracker.nr_word,
                                         'loss': tracker.loss, 'epoch': epoch}))
+        if texts_loc != '-':
+            texts = load_texts(texts_loc)
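The loop now distinguishes two input modes: passing '-' as texts_loc treats stdin as an endless JSONL stream and caps each "epoch" at 10**7 words, while a file path is reloaded (and therefore reshuffled) after every epoch. A schematic sketch of just that control flow, with make_file_texts() and make_stream() as hypothetical stand-ins for load_texts() and stream_texts(), and the actual batching and weight updates elided:

def run_epochs(texts_loc, make_file_texts, make_stream, nr_iter=2, epoch_words=10**7):
    # '-' selects the streaming path; anything else is treated as a JSONL file.
    texts = make_stream() if texts_loc == '-' else make_file_texts()
    for epoch in range(nr_iter):
        seen = 0
        for text in texts:
            seen += len(text.split())   # crude word count in place of Doc lengths
            # ... minibatching, make_update() and progress reporting go here ...
            if texts_loc == '-' and seen >= epoch_words:
                break                   # a stdin "epoch" is just a word budget
        if texts_loc != '-':
            texts = make_file_texts()   # fresh, reshuffled list for the next epoch

Because stream_texts() is a generator, breaking out of the inner loop leaves the stream where it stopped, so the next epoch simply keeps consuming stdin; the file-backed mode rebuilds its list instead.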
@@ -465,7 +465,7 @@ def decaying(start, stop, decay):
         nr_upd += 1


-def minibatch_by_words(items, size, count_words=len):
+def minibatch_by_words(items, size, tuples=True, count_words=len):
     '''Create minibatches of a given number of words.'''
     if isinstance(size, int):
         size_ = itertools.repeat(size)
@@ -477,13 +477,19 @@ def minibatch_by_words(items, size, count_words=len):
         batch = []
         while batch_size >= 0:
             try:
-                doc, gold = next(items)
+                if tuples:
+                    doc, gold = next(items)
+                else:
+                    doc = next(items)
             except StopIteration:
                 if batch:
                     yield batch
                 return
             batch_size -= count_words(doc)
-            batch.append((doc, gold))
+            if tuples:
+                batch.append((doc, gold))
+            else:
+                batch.append(doc)
         if batch:
             yield batch
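The last two hunks (in the utility module that exports minibatch_by_words, presumably spacy/util.py given the import at the top) add a tuples flag so the batcher can consume plain docs as well as (doc, gold) pairs; tuples=False is what the pretraining loop uses. A small usage sketch, assuming a spaCy version that includes this change; plain strings stand in for Doc objects, so the default count_words=len counts characters rather than tokens:

from spacy.util import minibatch_by_words

texts = ['a first text', 'a second text', 'a third one', 'and a fourth']

# tuples=False: each item is the doc itself, as in the pretraining script.
for batch in minibatch_by_words(iter(texts), size=30, tuples=False):
    print(batch)

# tuples=True (the default, and the previous behaviour): items are (doc, gold) pairs.
pairs = [(text, None) for text in texts]
for batch in minibatch_by_words(iter(pairs), size=30, tuples=True):
    print(batch)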