spaCy/bin/train_word_vectors.py
2019-07-11 13:05:53 +02:00

82 lines
2.2 KiB
Python

#!/usr/bin/env python
from __future__ import print_function, unicode_literals, division
import logging
from pathlib import Path
from collections import defaultdict
from gensim.models import Word2Vec
import plac
import spacy
logger = logging.getLogger(__name__)
class Corpus(object):
def __init__(self, directory, nlp):
self.directory = directory
self.nlp = nlp
def __iter__(self):
for text_loc in iter_dir(self.directory):
with text_loc.open("r", encoding="utf-8") as file_:
text = file_.read()
# This is to keep the input to the blank model (which doesn't
# sentencize) from being too long. It works particularly well with
# the output of [WikiExtractor](https://github.com/attardi/wikiextractor)
paragraphs = text.split('\n\n')
for par in paragraphs:
yield [word.orth_ for word in self.nlp(par)]
def iter_dir(loc):
dir_path = Path(loc)
for fn_path in dir_path.iterdir():
if fn_path.is_dir():
for sub_path in fn_path.iterdir():
yield sub_path
else:
yield fn_path
@plac.annotations(
lang=("ISO language code"),
in_dir=("Location of input directory"),
out_loc=("Location of output file"),
n_workers=("Number of workers", "option", "n", int),
size=("Dimension of the word vectors", "option", "d", int),
window=("Context window size", "option", "w", int),
min_count=("Min count", "option", "m", int),
negative=("Number of negative samples", "option", "g", int),
nr_iter=("Number of iterations", "option", "i", int),
)
def main(
lang,
in_dir,
out_loc,
negative=5,
n_workers=4,
window=5,
size=128,
min_count=10,
nr_iter=5,
):
logging.basicConfig(
format="%(asctime)s : %(levelname)s : %(message)s", level=logging.INFO
)
nlp = spacy.blank(lang)
corpus = Corpus(in_dir, nlp)
model = Word2Vec(
sentences=corpus,
size=size,
window=window,
min_count=min_count,
workers=n_workers,
sample=1e-5,
negative=negative,
)
model.save(out_loc)
if __name__ == "__main__":
plac.call(main)