Change vector training to work with latest gensim (fix #3749) (#3757)

This commit is contained in:
Paul O'Leary McCann 2019-06-16 20:24:06 +09:00 committed by Matthew Honnibal
parent d8573ee715
commit 3f52e12335

View File

@ -13,23 +13,21 @@ logger = logging.getLogger(__name__)
class Corpus(object): class Corpus(object):
def __init__(self, directory, min_freq=10): def __init__(self, directory, nlp):
self.directory = directory self.directory = directory
self.counts = PreshCounter() self.nlp = nlp
self.strings = {}
self.min_freq = min_freq
def count_doc(self, doc):
# Get counts for this document
for word in doc:
self.counts.inc(word.orth, 1)
return len(doc)
def __iter__(self): def __iter__(self):
for text_loc in iter_dir(self.directory): for text_loc in iter_dir(self.directory):
with text_loc.open("r", encoding="utf-8") as file_: with text_loc.open("r", encoding="utf-8") as file_:
text = file_.read() text = file_.read()
yield text
# This is to keep the input to the blank model (which doesn't
# sentencize) from being too long. It works particularly well with
# the output of [WikiExtractor](https://github.com/attardi/wikiextractor)
paragraphs = text.split('\n\n')
for par in paragraphs:
yield [word.orth_ for word in self.nlp(par)]
def iter_dir(loc): def iter_dir(loc):
@ -62,12 +60,15 @@ def main(
window=5, window=5,
size=128, size=128,
min_count=10, min_count=10,
nr_iter=2, nr_iter=5,
): ):
logging.basicConfig( logging.basicConfig(
format="%(asctime)s : %(levelname)s : %(message)s", level=logging.INFO format="%(asctime)s : %(levelname)s : %(message)s", level=logging.INFO
) )
nlp = spacy.blank(lang)
corpus = Corpus(in_dir, nlp)
model = Word2Vec( model = Word2Vec(
sentences=corpus,
size=size, size=size,
window=window, window=window,
min_count=min_count, min_count=min_count,
@ -75,33 +76,7 @@ def main(
sample=1e-5, sample=1e-5,
negative=negative, negative=negative,
) )
nlp = spacy.blank(lang)
corpus = Corpus(in_dir)
total_words = 0
total_sents = 0
for text_no, text_loc in enumerate(iter_dir(corpus.directory)):
with text_loc.open("r", encoding="utf-8") as file_:
text = file_.read()
total_sents += text.count("\n")
doc = nlp(text)
total_words += corpus.count_doc(doc)
logger.info(
"PROGRESS: at batch #%i, processed %i words, keeping %i word types",
text_no,
total_words,
len(corpus.strings),
)
model.corpus_count = total_sents
model.raw_vocab = defaultdict(int)
for orth, freq in corpus.counts:
if freq >= min_count:
model.raw_vocab[nlp.vocab.strings[orth]] = freq
model.scale_vocab()
model.finalize_vocab()
model.iter = nr_iter
model.train(corpus)
model.save(out_loc) model.save(out_loc)
if __name__ == "__main__": if __name__ == "__main__":
plac.call(main) plac.call(main)