Format example

This commit is contained in:
Ines Montani 2018-12-02 04:28:34 +01:00
parent 45798cc53e
commit 40b57ea4ac

View File

@ -55,15 +55,15 @@ import spacy
patterns_loc=("Path to gazetteer", "positional", None, str), patterns_loc=("Path to gazetteer", "positional", None, str),
text_loc=("Path to Reddit corpus file", "positional", None, str), text_loc=("Path to Reddit corpus file", "positional", None, str),
n=("Number of texts to read", "option", "n", int), n=("Number of texts to read", "option", "n", int),
lang=("Language class to initialise", "option", "l", str)) lang=("Language class to initialise", "option", "l", str),
def main(patterns_loc, text_loc, n=10000, lang='en'): )
nlp = spacy.blank('en') def main(patterns_loc, text_loc, n=10000, lang="en"):
nlp = spacy.blank("en")
nlp.vocab.lex_attr_getters = {} nlp.vocab.lex_attr_getters = {}
phrases = read_gazetteer(nlp.tokenizer, patterns_loc) phrases = read_gazetteer(nlp.tokenizer, patterns_loc)
count = 0 count = 0
t1 = time.time() t1 = time.time()
for ent_id, text in get_matches(nlp.tokenizer, phrases, for ent_id, text in get_matches(nlp.tokenizer, phrases, read_text(text_loc, n=n)):
read_text(text_loc, n=n)):
count += 1 count += 1
t2 = time.time() t2 = time.time()
print("%d docs in %.3f s. %d matches" % (n, (t2 - t1), count)) print("%d docs in %.3f s. %d matches" % (n, (t2 - t1), count))
@ -72,7 +72,7 @@ def main(patterns_loc, text_loc, n=10000, lang='en'):
def read_gazetteer(tokenizer, loc, n=-1): def read_gazetteer(tokenizer, loc, n=-1):
for i, line in enumerate(open(loc)): for i, line in enumerate(open(loc)):
data = ujson.loads(line.strip()) data = ujson.loads(line.strip())
phrase = tokenizer(data['text']) phrase = tokenizer(data["text"])
for w in phrase: for w in phrase:
_ = tokenizer.vocab[w.text] _ = tokenizer.vocab[w.text]
if len(phrase) >= 2: if len(phrase) >= 2:
@ -83,14 +83,14 @@ def read_text(bz2_loc, n=10000):
with BZ2File(bz2_loc) as file_: with BZ2File(bz2_loc) as file_:
for i, line in enumerate(file_): for i, line in enumerate(file_):
data = ujson.loads(line) data = ujson.loads(line)
yield data['body'] yield data["body"]
if i >= n: if i >= n:
break break
def get_matches(tokenizer, phrases, texts, max_length=6): def get_matches(tokenizer, phrases, texts, max_length=6):
matcher = PhraseMatcher(tokenizer.vocab, max_length=max_length) matcher = PhraseMatcher(tokenizer.vocab, max_length=max_length)
matcher.add('Phrase', None, *phrases) matcher.add("Phrase", None, *phrases)
for text in texts: for text in texts:
doc = tokenizer(text) doc = tokenizer(text)
for w in doc: for w in doc:
@ -100,10 +100,11 @@ def get_matches(tokenizer, phrases, texts, max_length=6):
yield (ent_id, doc[start:end].text) yield (ent_id, doc[start:end].text)
if __name__ == '__main__': if __name__ == "__main__":
if False: if False:
import cProfile import cProfile
import pstats import pstats
cProfile.runctx("plac.call(main)", globals(), locals(), "Profile.prof") cProfile.runctx("plac.call(main)", globals(), locals(), "Profile.prof")
s = pstats.Stats("Profile.prof") s = pstats.Stats("Profile.prof")
s.strip_dirs().sort_stats("time").print_stats() s.strip_dirs().sort_stats("time").print_stats()