mirror of https://github.com/explosion/spaCy.git
Make pretrain script work with stream from stdin

commit 3e7b214e57
parent 8fdb9bc278

In the pretrain script:
@@ -20,10 +20,11 @@ import numpy
 import time
 import ujson as json
 from pathlib import Path
+import sys
 
 import spacy
 from spacy.attrs import ID
-from spacy.util import minibatch, use_gpu, compounding, ensure_path
+from spacy.util import minibatch_by_words, use_gpu, compounding, ensure_path
 from spacy._ml import Tok2Vec, flatten, chain, zero_init, create_default_optimizer
 from thinc.v2v import Affine
 
@@ -45,9 +46,13 @@ def load_texts(path):
     '''
     path = ensure_path(path)
     with path.open('r', encoding='utf8') as file_:
-        for line in file_:
-            data = json.loads(line)
-            yield data['text']
+        texts = [json.loads(line)['text'] for line in file_]
+    random.shuffle(texts)
+    return texts
+
+def stream_texts():
+    for line in sys.stdin:
+        yield json.loads(line)['text']
 
 
 def make_update(model, docs, optimizer, drop=0.):
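The new stream_texts() generator reads one JSON object per line from stdin, so the corpus never has to fit in memory. A minimal smoke test of the same pattern (the io.StringIO stand-in for stdin and the pretrain.py script name are illustrative assumptions, not part of the commit):

import io
import json
import sys

def stream_texts():
    # One JSON object per line, e.g. {"text": "Some raw text."}
    for line in sys.stdin:
        yield json.loads(line)['text']

# Simulate piped input; in real use you would run something like
#   cat texts.jsonl | python pretrain.py - <vectors_model> <output_dir>
sys.stdin = io.StringIO('{"text": "hello world"}\n{"text": "second doc"}\n')
for text in stream_texts():
    print(text)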
@@ -102,16 +107,19 @@ def create_pretraining_model(nlp, tok2vec):
 
 
 class ProgressTracker(object):
-    def __init__(self, frequency=10000):
+    def __init__(self, frequency=100000):
         self.loss = 0.
         self.nr_word = 0
+        self.words_per_epoch = Counter()
         self.frequency = frequency
         self.last_time = time.time()
         self.last_update = 0
 
     def update(self, epoch, loss, docs):
         self.loss += loss
-        self.nr_word += sum(len(doc) for doc in docs)
+        words_in_batch = sum(len(doc) for doc in docs)
+        self.words_per_epoch[epoch] += words_in_batch
+        self.nr_word += words_in_batch
         words_since_update = self.nr_word - self.last_update
         if words_since_update >= self.frequency:
             wps = words_since_update / (time.time() - self.last_time)
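words_per_epoch is a collections.Counter keyed by epoch number, which lets the stdin code path treat an "epoch" as a word budget rather than a full pass over a file. A toy illustration of the cutoff logic, with made-up batch sizes:

from collections import Counter

words_per_epoch = Counter()
epoch = 0
for words_in_batch in (4 * 10**6, 4 * 10**6, 4 * 10**6):
    words_per_epoch[epoch] += words_in_batch
    # Mirrors the 10**7-word cap applied in the training loop below
    # when the texts come from stdin.
    if words_per_epoch[epoch] >= 10**7:
        print('epoch %d capped at %d words' % (epoch, words_per_epoch[epoch]))
        break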
@@ -170,19 +178,22 @@ def pretrain(texts_loc, vectors_model, output_dir, width=128, depth=4,
     model = create_pretraining_model(nlp, tok2vec)
     optimizer = create_default_optimizer(model.ops)
     tracker = ProgressTracker()
-    texts = list(load_texts(texts_loc))
     print('Epoch', '#Words', 'Loss', 'w/s')
+    texts = stream_texts() if texts_loc == '-' else load_texts(texts_loc)
     for epoch in range(nr_iter):
-        random.shuffle(texts)
-        for batch in minibatch(texts):
+        for batch in minibatch_by_words(texts, tuples=False, size=50000):
             docs = [nlp.make_doc(text) for text in batch]
             loss = make_update(model, docs, optimizer, drop=dropout)
             progress = tracker.update(epoch, loss, docs)
             if progress:
                 print(*progress)
+            if texts_loc == '-' and tracker.words_per_epoch[epoch] >= 10**7:
+                break
         with model.use_params(optimizer.averages):
             with (output_dir / ('model%d.bin' % epoch)).open('wb') as file_:
                 file_.write(tok2vec.to_bytes())
             with (output_dir / 'log.jsonl').open('a') as file_:
                 file_.write(json.dumps({'nr_word': tracker.nr_word,
                     'loss': tracker.loss, 'epoch': epoch}))
+        if texts_loc != '-':
+            texts = load_texts(texts_loc)
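The reload at the end of the loop exists because a stdin stream cannot be rewound: a generator is drained after one pass, while a list can be reshuffled and reread each epoch. A self-contained sketch of that difference:

import random

texts = ['a', 'b', 'c']
stream = (t for t in texts)   # like stream_texts(): a one-shot generator
print(list(stream))           # ['a', 'b', 'c']
print(list(stream))           # [] -- drained, no second epoch possible

random.shuffle(texts)         # like load_texts(): a list can be shuffled
print(texts)                  # and iterated again for the next epoch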
In spacy/util.py:

@@ -465,7 +465,7 @@ def decaying(start, stop, decay):
         nr_upd += 1
 
 
-def minibatch_by_words(items, size, count_words=len):
+def minibatch_by_words(items, size, tuples=True, count_words=len):
     '''Create minibatches of a given number of words.'''
     if isinstance(size, int):
         size_ = itertools.repeat(size)
@@ -477,13 +477,19 @@ def minibatch_by_words(items, size, count_words=len):
         batch = []
         while batch_size >= 0:
             try:
-                doc, gold = next(items)
+                if tuples:
+                    doc, gold = next(items)
+                else:
+                    doc = next(items)
             except StopIteration:
                 if batch:
                     yield batch
                 return
             batch_size -= count_words(doc)
-            batch.append((doc, gold))
+            if tuples:
+                batch.append((doc, gold))
+            else:
+                batch.append(doc)
         if batch:
             yield batch
 
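A quick sketch of the new tuples flag: with tuples=False the batcher yields bare items instead of (doc, gold) pairs. Plain strings stand in for Doc objects, so the default count_words=len counts characters rather than words; an iterator is passed because the function pulls items with next(). Assumes a spaCy version that ships this signature:

from spacy.util import minibatch_by_words

texts = ['one two three', 'four five', 'six', 'seven eight nine ten']
for batch in minibatch_by_words(iter(texts), size=20, tuples=False):
    print(batch)
# -> ['one two three', 'four five']
# -> ['six', 'seven eight nine ten']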