spaCy/bin/get_freqs.py

93 lines
2.7 KiB
Python
Raw Normal View History

2015-07-14 03:31:32 +03:00
#!/usr/bin/env python
2015-10-15 20:20:35 +03:00
from __future__ import unicode_literals, print_function
2015-07-14 03:31:32 +03:00
import plac
import joblib
from os import path
import os
import bz2
import ujson
import codecs
from preshed.counter import PreshCounter
from joblib import Parallel, delayed
import spacy.en
from spacy.strings import StringStore
2015-10-15 20:20:35 +03:00
from spacy.attrs import ORTH
2015-10-15 20:24:08 +03:00
from spacy.tokenizer import Tokenizer
from spacy.vocab import Vocab
2015-07-14 03:31:32 +03:00
def iter_comments(loc):
    """Yield the JSON record on each line of a bz2-compressed file.

    loc: path to the compressed file. Lines are decoded one at a time with
    ``ujson.loads`` and yielded lazily (this is a generator).
    """
    with bz2.BZ2File(loc) as compressed:
        for raw_line in compressed:
            record = ujson.loads(raw_line)
            yield record
def count_freqs(input_loc, output_loc):
    """Count token frequencies for one compressed comment dump.

    input_loc: bz2 file of JSON records, each read through its 'body' field.
    output_loc: UTF-8 text file to write, one "<count>\\t<string>" line per
        distinct token string; whitespace-only strings are omitted.
    """
    # Printed before any work starts, so it doubles as progress output.
    print(output_loc)
    tokenizer_dir = path.join(spacy.en.English.default_data_dir(), 'tokenizer')
    tokenizer = Tokenizer.from_dir(Vocab(), tokenizer_dir)

    counts = PreshCounter()
    for comment in iter_comments(input_loc):
        tokenizer(comment['body']).count_by(ORTH, counts=counts)

    strings = tokenizer.vocab.strings
    with codecs.open(output_loc, 'w', 'utf8') as out_file:
        for orth_id, freq in counts:
            text = strings[orth_id]
            if text.isspace():
                continue
            out_file.write('%d\t%s\n' % (freq, text))
2015-07-14 03:31:32 +03:00
def parallelize(func, iterator, n_jobs):
    """Call ``func(*args)`` for every argument tuple in *iterator* via joblib.

    The calls are dispatched to a joblib ``Parallel`` pool of *n_jobs*
    workers. Their results are discarded and this function returns None.
    """
    jobs = (delayed(func)(*args) for args in iterator)
    runner = Parallel(n_jobs=n_jobs)
    runner(jobs)
def merge_counts(locs, out_loc):
    """Sum several per-file frequency tables into one table.

    locs: paths of UTF-8 files in the "<count>\\t<string>" line format that
        count_freqs writes.
    out_loc: path of the merged UTF-8 output, written in the same format.
    """
    # ``io`` was never imported at file level, so the io.open calls below
    # raised NameError; import it here so this function is self-contained.
    import io

    string_map = StringStore()
    counts = PreshCounter()
    for loc in locs:
        with io.open(loc, 'r', encoding='utf8') as file_:
            for line in file_:
                line = line.strip()
                if not line:
                    # A blank line would make the unpack below raise.
                    continue
                freq, word = line.split('\t', 1)
                # Presumably indexing StringStore with a string returns its
                # integer id, and indexing with that id (below) returns the
                # string again. This is the round-trip the code relies on.
                orth = string_map[word]
                counts.inc(orth, int(freq))
    with io.open(out_loc, 'w', encoding='utf8') as file_:
        for orth, count in counts:
            string = string_map[orth]
            file_.write('%d\t%s\n' % (count, string))
def _freq_filename(input_path):
    """Return the output file name for *input_path*.

    The name is the path's base name with only the LAST 'bz2' replaced by
    'freq', so 'RC_2015-01.bz2' becomes 'RC_2015-01.freq'. Names without
    'bz2' are returned unchanged.
    """
    filename = path.basename(input_path)
    head, sep, tail = filename.rpartition('bz2')
    if not sep:
        return filename
    return head + 'freq' + tail


@plac.annotations(
    input_loc=("Location of input file list"),
    freqs_dir=("Directory for frequency files"),
    output_loc=("Location for output file"),
    n_jobs=("Number of workers", "option", "n", int),
    skip_existing=("Skip inputs where an output file exists", "flag", "s", bool),
)
def main(input_loc, freqs_dir, output_loc, n_jobs=2, skip_existing=False):
    """Count token frequencies for each listed input, then merge the results.

    input_loc: text file listing one input path per line; blank lines are
        ignored.
    freqs_dir: directory that receives one frequency file per input.
    output_loc: path of the merged frequency table.
    n_jobs: number of joblib workers used for the counting stage.
    skip_existing: if set, inputs whose per-file output already exists are
        not recounted, but their existing output is still merged.
    """
    tasks = []
    outputs = []
    # Use ``with`` so the list file is closed rather than leaked.
    with open(input_loc) as list_file:
        for input_path in list_file:
            input_path = input_path.strip()
            if not input_path:
                continue
            output_path = path.join(freqs_dir, _freq_filename(input_path))
            outputs.append(output_path)
            if not path.exists(output_path) or not skip_existing:
                tasks.append((input_path, output_path))

    if tasks:
        parallelize(count_freqs, tasks, n_jobs)

    print("Merge")
    merge_counts(outputs, output_loc)
2015-07-14 03:31:32 +03:00
if __name__ == "__main__":
    # Let plac build the command-line interface from main's annotations.
    plac.call(main)