Fix PhraseMatcher example

Matthew Honnibal 2017-09-20 22:51:41 +02:00
parent 0c93c73e49
commit 01858e9b59


@@ -20,72 +20,72 @@ The algorithm is O(n) at run-time for document of length n because we're only ev
 matching over the tag patterns. So no matter how many phrases we're looking for,
 our pattern set stays very small (exact size depends on the maximum length we're
 looking for, as the query language currently has no quantifiers)
+
+The example expects a .bz2 file from the Reddit corpus, and a patterns file,
+formatted in jsonl as a sequence of entries like this:
+
+{"text":"Anchorage"}
+{"text":"Angola"}
+{"text":"Ann Arbor"}
+{"text":"Annapolis"}
+{"text":"Appalachia"}
+{"text":"Argentina"}
 """
 from __future__ import print_function, unicode_literals, division
-from ast import literal_eval
 from bz2 import BZ2File
 import time
 import math
 import codecs

 import plac
+import ujson

-from preshed.maps import PreshMap
-from preshed.counter import PreshCounter
-from spacy.strings import hash_string
-from spacy.en import English
 from spacy.matcher import PhraseMatcher
+import spacy


 def read_gazetteer(tokenizer, loc, n=-1):
     for i, line in enumerate(open(loc)):
-        phrase = literal_eval('u' + line.strip())
-        if ' (' in phrase and phrase.endswith(')'):
-            phrase = phrase.split(' (', 1)[0]
-        if i >= n:
-            break
-        phrase = tokenizer(phrase)
-        if all((t.is_lower and t.prob >= -10) for t in phrase):
-            continue
+        data = ujson.loads(line.strip())
+        phrase = tokenizer(data['text'])
+        for w in phrase:
+            _ = tokenizer.vocab[w.text]
         if len(phrase) >= 2:
             yield phrase


-def read_text(bz2_loc):
+def read_text(bz2_loc, n=10000):
     with BZ2File(bz2_loc) as file_:
-        for line in file_:
-            yield line.decode('utf8')
+        for i, line in enumerate(file_):
+            data = ujson.loads(line)
+            yield data['body']
+            if i >= n:
+                break


 def get_matches(tokenizer, phrases, texts, max_length=6):
-    matcher = PhraseMatcher(tokenizer.vocab, phrases, max_length=max_length)
-    print("Match")
+    matcher = PhraseMatcher(tokenizer.vocab, max_length=max_length)
+    matcher.add('Phrase', None, *phrases)
     for text in texts:
         doc = tokenizer(text)
+        for w in doc:
+            _ = doc.vocab[w.text]
         matches = matcher(doc)
-        for mwe in doc.ents:
-            yield mwe
+        for ent_id, start, end in matches:
+            yield (ent_id, doc[start:end].text)


-def main(patterns_loc, text_loc, counts_loc, n=10000000):
-    nlp = English(parser=False, tagger=False, entity=False)
-    print("Make matcher")
-    phrases = read_gazetteer(nlp.tokenizer, patterns_loc, n=n)
-    counts = PreshCounter()
+def main(patterns_loc, text_loc, n=10000):
+    nlp = spacy.blank('en')
+    nlp.vocab.lex_attr_getters = {}
+    phrases = read_gazetteer(nlp.tokenizer, patterns_loc)
+    count = 0
     t1 = time.time()
-    for mwe in get_matches(nlp.tokenizer, phrases, read_text(text_loc)):
-        counts.inc(hash_string(mwe.text), 1)
+    for ent_id, text in get_matches(nlp.tokenizer, phrases, read_text(text_loc, n=n)):
+        count += 1
     t2 = time.time()
-    print("10m tokens in %d s" % (t2 - t1))
-    with codecs.open(counts_loc, 'w', 'utf8') as file_:
-        for phrase in read_gazetteer(nlp.tokenizer, patterns_loc, n=n):
-            text = phrase.string
-            key = hash_string(text)
-            count = counts[key]
-            if count != 0:
-                file_.write('%d\t%s\n' % (count, text))
+    print("%d docs in %.3f s. %d matches" % (n, (t2 - t1), count))


 if __name__ == '__main__':
     if False:
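
The core of the change is the move to the add-based PhraseMatcher API: patterns are pre-tokenized Docs registered under a key with matcher.add('Phrase', None, *phrases), instead of being passed to the constructor. Below is a minimal, self-contained sketch of that usage pattern, assuming the spaCy 2.x API this commit targets; the phrase list and sample sentence are illustrative and not taken from the patterns file or Reddit corpus above.

    import spacy
    from spacy.matcher import PhraseMatcher

    # Blank English pipeline, as in the updated example (no tagger/parser needed).
    nlp = spacy.blank('en')
    matcher = PhraseMatcher(nlp.vocab)

    # Patterns are pre-tokenized Docs; one key can hold many phrases.
    phrases = [nlp.tokenizer(text) for text in ('Anchorage', 'Ann Arbor', 'Argentina')]
    matcher.add('Phrase', None, *phrases)

    doc = nlp.tokenizer('She moved from Anchorage to Ann Arbor last year.')
    for match_id, start, end in matcher(doc):
        # match_id is the hash of the key; doc[start:end] is the matched span.
        print(nlp.vocab.strings[match_id], doc[start:end].text)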