mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-10 19:57:17 +03:00
Improve train tensorizer script
This commit is contained in:
parent
ba365ae1c9
commit
0127f10ba3
|
@ -2,7 +2,7 @@
|
||||||
import plac
|
import plac
|
||||||
import spacy
|
import spacy
|
||||||
import thinc.extra.datasets
|
import thinc.extra.datasets
|
||||||
from spacy.util import minibatch
|
from spacy.util import minibatch, use_gpu
|
||||||
import tqdm
|
import tqdm
|
||||||
|
|
||||||
|
|
||||||
|
@ -12,7 +12,7 @@ def load_imdb():
|
||||||
train_texts, _ = zip(*train)
|
train_texts, _ = zip(*train)
|
||||||
dev_texts, _ = zip(*dev)
|
dev_texts, _ = zip(*dev)
|
||||||
nlp.add_pipe(nlp.create_pipe('sentencizer'))
|
nlp.add_pipe(nlp.create_pipe('sentencizer'))
|
||||||
return list(get_sentences(nlp, train_texts)), list(get_sentences(nlp, dev_texts))
|
return list(train_texts), list(dev_texts)
|
||||||
|
|
||||||
|
|
||||||
def get_sentences(nlp, texts):
|
def get_sentences(nlp, texts):
|
||||||
|
@ -21,12 +21,20 @@ def get_sentences(nlp, texts):
|
||||||
yield sent.text
|
yield sent.text
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def prefer_gpu():
|
||||||
|
used = spacy.util.use_gpu(0)
|
||||||
|
if used is None:
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
return True
|
||||||
|
|
||||||
|
def main(vectors_model):
|
||||||
|
use_gpu = prefer_gpu()
|
||||||
|
print("Using GPU?", use_gpu)
|
||||||
print("Load data")
|
print("Load data")
|
||||||
train_texts, dev_texts = load_imdb()
|
train_texts, dev_texts = load_imdb()
|
||||||
train_texts = train_texts[:1000]
|
|
||||||
print("Load vectors")
|
print("Load vectors")
|
||||||
nlp = spacy.load('en_vectors_web_lg')
|
nlp = spacy.load(vectors_model)
|
||||||
print("Start training")
|
print("Start training")
|
||||||
nlp.add_pipe(nlp.create_pipe('tagger'))
|
nlp.add_pipe(nlp.create_pipe('tagger'))
|
||||||
tensorizer = nlp.create_pipe('tensorizer')
|
tensorizer = nlp.create_pipe('tensorizer')
|
||||||
|
@ -38,8 +46,7 @@ def main():
|
||||||
for i, batch in enumerate(minibatch(tqdm.tqdm(train_texts))):
|
for i, batch in enumerate(minibatch(tqdm.tqdm(train_texts))):
|
||||||
docs = [nlp.make_doc(text) for text in batch]
|
docs = [nlp.make_doc(text) for text in batch]
|
||||||
tensorizer.update(docs, None, losses=losses, sgd=optimizer, drop=0.5)
|
tensorizer.update(docs, None, losses=losses, sgd=optimizer, drop=0.5)
|
||||||
if i % 10 == 0:
|
print(losses)
|
||||||
print(losses)
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
plac.call(main)
|
plac.call(main)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user