Fix PhraseMatcher example

Matthew Honnibal 2017-09-20 22:51:41 +02:00
parent 0c93c73e49
commit 01858e9b59

@@ -20,72 +20,72 @@ The algorithm is O(n) at run-time for document of length n because we're only ever
 matching over the tag patterns. So no matter how many phrases we're looking for,
 our pattern set stays very small (exact size depends on the maximum length we're
 looking for, as the query language currently has no quantifiers)
+
+The example expects a .bz2 file from the Reddit corpus, and a patterns file,
+formatted in jsonl as a sequence of entries like this:
+
+{"text":"Anchorage"}
+{"text":"Angola"}
+{"text":"Ann Arbor"}
+{"text":"Annapolis"}
+{"text":"Appalachia"}
+{"text":"Argentina"}
"""
from __future__ import print_function, unicode_literals, division
from ast import literal_eval
from bz2 import BZ2File
import time
import math
import codecs
import plac
import ujson
from preshed.maps import PreshMap
from preshed.counter import PreshCounter
from spacy.strings import hash_string
from spacy.en import English
from spacy.matcher import PhraseMatcher
import spacy
 
 
 def read_gazetteer(tokenizer, loc, n=-1):
     for i, line in enumerate(open(loc)):
-        phrase = literal_eval('u' + line.strip())
-        if ' (' in phrase and phrase.endswith(')'):
-            phrase = phrase.split(' (', 1)[0]
         if i >= n:
             break
-        phrase = tokenizer(phrase)
-        if all((t.is_lower and t.prob >= -10) for t in phrase):
-            continue
+        data = ujson.loads(line.strip())
+        phrase = tokenizer(data['text'])
+        for w in phrase:
+            _ = tokenizer.vocab[w.text]
         if len(phrase) >= 2:
            yield phrase
 
 
-def read_text(bz2_loc):
+def read_text(bz2_loc, n=10000):
     with BZ2File(bz2_loc) as file_:
-        for line in file_:
-            yield line.decode('utf8')
+        for i, line in enumerate(file_):
+            data = ujson.loads(line)
+            yield data['body']
+            if i >= n:
+                break
 
 
 def get_matches(tokenizer, phrases, texts, max_length=6):
-    matcher = PhraseMatcher(tokenizer.vocab, phrases, max_length=max_length)
-    print("Match")
+    matcher = PhraseMatcher(tokenizer.vocab, max_length=max_length)
+    matcher.add('Phrase', None, *phrases)
     for text in texts:
         doc = tokenizer(text)
+        for w in doc:
+            _ = doc.vocab[w.text]
         matches = matcher(doc)
-        for mwe in doc.ents:
-            yield mwe
+        for ent_id, start, end in matches:
+            yield (ent_id, doc[start:end].text)
 
 
-def main(patterns_loc, text_loc, counts_loc, n=10000000):
-    nlp = English(parser=False, tagger=False, entity=False)
-    print("Make matcher")
-    phrases = read_gazetteer(nlp.tokenizer, patterns_loc, n=n)
-    counts = PreshCounter()
+def main(patterns_loc, text_loc, n=10000):
+    nlp = spacy.blank('en')
+    nlp.vocab.lex_attr_getters = {}
+    phrases = read_gazetteer(nlp.tokenizer, patterns_loc)
+    count = 0
     t1 = time.time()
-    for mwe in get_matches(nlp.tokenizer, phrases, read_text(text_loc)):
-        counts.inc(hash_string(mwe.text), 1)
+    for ent_id, text in get_matches(nlp.tokenizer, phrases, read_text(text_loc, n=n)):
+        count += 1
     t2 = time.time()
-    print("10m tokens in %d s" % (t2 - t1))
-    with codecs.open(counts_loc, 'w', 'utf8') as file_:
-        for phrase in read_gazetteer(nlp.tokenizer, patterns_loc, n=n):
-            text = phrase.string
-            key = hash_string(text)
-            count = counts[key]
-            if count != 0:
-                file_.write('%d\t%s\n' % (count, text))
+    print("%d docs in %.3f s. %d matches" % (n, (t2 - t1), count))
 
 
 if __name__ == '__main__':
     if False:
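
For context on the API change this commit makes, here is a minimal, self-contained sketch of the spaCy 2.x-style PhraseMatcher usage the updated example relies on. The 'Phrase' label, sample phrases, and input sentence are illustrative only, not part of the example script.

import spacy
from spacy.matcher import PhraseMatcher

nlp = spacy.blank('en')
matcher = PhraseMatcher(nlp.vocab)

# Patterns are pre-tokenized Doc objects. The second argument to add() is an
# optional on-match callback (None here), followed by one or more patterns.
patterns = [nlp.tokenizer(text) for text in ('Ann Arbor', 'Anchorage')]
matcher.add('Phrase', None, *patterns)

doc = nlp.tokenizer('I drove from Ann Arbor to Anchorage last fall')
for match_id, start, end in matcher(doc):
    # match_id is the hash of the label; doc[start:end] is the matched span.
    print(nlp.vocab.strings[match_id], doc[start:end].text)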
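Purely as an illustration of the patterns file format described in the new docstring, a small jsonl file could be produced like this; the output file name and phrase list are hypothetical.

import ujson

# Write one {"text": ...} entry per line, the format read_gazetteer expects.
phrases = ['Anchorage', 'Angola', 'Ann Arbor', 'Annapolis']
with open('patterns.jsonl', 'w') as f:
    for text in phrases:
        f.write(ujson.dumps({'text': text}) + '\n')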