mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-26 01:46:28 +03:00
parent
d8573ee715
commit
3f52e12335
|
@ -13,23 +13,21 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class Corpus(object):
|
class Corpus(object):
|
||||||
def __init__(self, directory, min_freq=10):
|
def __init__(self, directory, nlp):
|
||||||
self.directory = directory
|
self.directory = directory
|
||||||
self.counts = PreshCounter()
|
self.nlp = nlp
|
||||||
self.strings = {}
|
|
||||||
self.min_freq = min_freq
|
|
||||||
|
|
||||||
def count_doc(self, doc):
|
|
||||||
# Get counts for this document
|
|
||||||
for word in doc:
|
|
||||||
self.counts.inc(word.orth, 1)
|
|
||||||
return len(doc)
|
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
for text_loc in iter_dir(self.directory):
|
for text_loc in iter_dir(self.directory):
|
||||||
with text_loc.open("r", encoding="utf-8") as file_:
|
with text_loc.open("r", encoding="utf-8") as file_:
|
||||||
text = file_.read()
|
text = file_.read()
|
||||||
yield text
|
|
||||||
|
# This is to keep the input to the blank model (which doesn't
|
||||||
|
# sentencize) from being too long. It works particularly well with
|
||||||
|
# the output of [WikiExtractor](https://github.com/attardi/wikiextractor)
|
||||||
|
paragraphs = text.split('\n\n')
|
||||||
|
for par in paragraphs:
|
||||||
|
yield [word.orth_ for word in self.nlp(par)]
|
||||||
|
|
||||||
|
|
||||||
def iter_dir(loc):
|
def iter_dir(loc):
|
||||||
|
@ -62,12 +60,15 @@ def main(
|
||||||
window=5,
|
window=5,
|
||||||
size=128,
|
size=128,
|
||||||
min_count=10,
|
min_count=10,
|
||||||
nr_iter=2,
|
nr_iter=5,
|
||||||
):
|
):
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
format="%(asctime)s : %(levelname)s : %(message)s", level=logging.INFO
|
format="%(asctime)s : %(levelname)s : %(message)s", level=logging.INFO
|
||||||
)
|
)
|
||||||
|
nlp = spacy.blank(lang)
|
||||||
|
corpus = Corpus(in_dir, nlp)
|
||||||
model = Word2Vec(
|
model = Word2Vec(
|
||||||
|
sentences=corpus,
|
||||||
size=size,
|
size=size,
|
||||||
window=window,
|
window=window,
|
||||||
min_count=min_count,
|
min_count=min_count,
|
||||||
|
@ -75,33 +76,7 @@ def main(
|
||||||
sample=1e-5,
|
sample=1e-5,
|
||||||
negative=negative,
|
negative=negative,
|
||||||
)
|
)
|
||||||
nlp = spacy.blank(lang)
|
|
||||||
corpus = Corpus(in_dir)
|
|
||||||
total_words = 0
|
|
||||||
total_sents = 0
|
|
||||||
for text_no, text_loc in enumerate(iter_dir(corpus.directory)):
|
|
||||||
with text_loc.open("r", encoding="utf-8") as file_:
|
|
||||||
text = file_.read()
|
|
||||||
total_sents += text.count("\n")
|
|
||||||
doc = nlp(text)
|
|
||||||
total_words += corpus.count_doc(doc)
|
|
||||||
logger.info(
|
|
||||||
"PROGRESS: at batch #%i, processed %i words, keeping %i word types",
|
|
||||||
text_no,
|
|
||||||
total_words,
|
|
||||||
len(corpus.strings),
|
|
||||||
)
|
|
||||||
model.corpus_count = total_sents
|
|
||||||
model.raw_vocab = defaultdict(int)
|
|
||||||
for orth, freq in corpus.counts:
|
|
||||||
if freq >= min_count:
|
|
||||||
model.raw_vocab[nlp.vocab.strings[orth]] = freq
|
|
||||||
model.scale_vocab()
|
|
||||||
model.finalize_vocab()
|
|
||||||
model.iter = nr_iter
|
|
||||||
model.train(corpus)
|
|
||||||
model.save(out_loc)
|
model.save(out_loc)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
plac.call(main)
|
plac.call(main)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user