diff --git a/bin/train_word_vectors.py b/bin/train_word_vectors.py index 8482a7a55..624e339a0 100644 --- a/bin/train_word_vectors.py +++ b/bin/train_word_vectors.py @@ -13,23 +13,21 @@ logger = logging.getLogger(__name__) class Corpus(object): - def __init__(self, directory, min_freq=10): + def __init__(self, directory, nlp): self.directory = directory - self.counts = PreshCounter() - self.strings = {} - self.min_freq = min_freq - - def count_doc(self, doc): - # Get counts for this document - for word in doc: - self.counts.inc(word.orth, 1) - return len(doc) + self.nlp = nlp def __iter__(self): for text_loc in iter_dir(self.directory): with text_loc.open("r", encoding="utf-8") as file_: text = file_.read() - yield text + + # This is to keep the input to the blank model (which doesn't + # sentencize) from being too long. It works particularly well with + # the output of [WikiExtractor](https://github.com/attardi/wikiextractor) + paragraphs = text.split('\n\n') + for par in paragraphs: + yield [word.orth_ for word in self.nlp(par)] def iter_dir(loc): @@ -62,12 +60,15 @@ def main( window=5, size=128, min_count=10, - nr_iter=2, + nr_iter=5, ): logging.basicConfig( format="%(asctime)s : %(levelname)s : %(message)s", level=logging.INFO ) + nlp = spacy.blank(lang) + corpus = Corpus(in_dir, nlp) model = Word2Vec( + sentences=corpus, size=size, window=window, min_count=min_count, @@ -75,33 +76,7 @@ def main( sample=1e-5, negative=negative, ) - nlp = spacy.blank(lang) - corpus = Corpus(in_dir) - total_words = 0 - total_sents = 0 - for text_no, text_loc in enumerate(iter_dir(corpus.directory)): - with text_loc.open("r", encoding="utf-8") as file_: - text = file_.read() - total_sents += text.count("\n") - doc = nlp(text) - total_words += corpus.count_doc(doc) - logger.info( - "PROGRESS: at batch #%i, processed %i words, keeping %i word types", - text_no, - total_words, - len(corpus.strings), - ) - model.corpus_count = total_sents - model.raw_vocab = defaultdict(int) - for orth, freq in corpus.counts: - if freq >= min_count: - model.raw_vocab[nlp.vocab.strings[orth]] = freq - model.scale_vocab() - model.finalize_vocab() - model.iter = nr_iter - model.train(corpus) model.save(out_loc) - if __name__ == "__main__": plac.call(main)