spaCy/examples/information_extraction/phrase_matcher.py

108 lines
3.6 KiB
Python
Raw Normal View History

2017-11-01 02:43:22 +03:00
#!/usr/bin/env python
# coding: utf8
2015-10-06 01:06:52 +03:00
"""Match a large set of multi-word expressions in O(1) time per token.

The idea is to associate each word in the vocabulary with a tag, noting whether
it begins, ends, or is inside at least one pattern. An additional tag is used
for single-word patterns. Complete patterns are also stored in a hash set.

When we process a document, we look up each word in the vocabulary to find its
tags. We then search for tag sequences that correspond to valid candidates.
Finally, we look up the candidates in the hash set.

For instance, to search for the phrases "Barack Hussein Obama" and "Hillary
Clinton", we would associate "Barack" and "Hillary" with the B tag, "Hussein"
with the I tag, and "Obama" and "Clinton" with the L tag.

The document "Barack Clinton and Hillary Clinton" would have the tag sequence
[{B}, {L}, {}, {B}, {L}], so we'd get two candidates. However, only the second
candidate is in the phrase dictionary, so only one is returned as a match.

The algorithm is O(n) at run-time for a document of length n, because we only
ever match over the tag patterns. No matter how many phrases we're looking for,
the set of tag patterns stays very small (its exact size depends on the maximum
phrase length, as the query language currently has no quantifiers).

The example expects a .bz2 file from the Reddit corpus (one JSON object per
line, with the comment text in its "body" field) and a patterns file in JSONL
format, with one entry per line like this:
{"text":"Anchorage"}
{"text":"Angola"}
{"text":"Ann Arbor"}
{"text":"Annapolis"}
{"text":"Appalachia"}
{"text":"Argentina"}
Only patterns that tokenize to two or more tokens (e.g. "Ann Arbor") are used.

Compatible with: spaCy v2.0.0+
"""
from __future__ import print_function, unicode_literals, division

from bz2 import BZ2File
import io
import time

import plac
import spacy
from spacy.matcher import PhraseMatcher
import ujson


@plac.annotations(
    patterns_loc=("Path to gazetteer", "positional", None, str),
    text_loc=("Path to Reddit corpus file", "positional", None, str),
    n=("Number of texts to read", "option", "n", int),
    lang=("Language class to initialise", "option", "l", str))
def main(patterns_loc, text_loc, n=10000, lang='en'):
    """Run the PhraseMatcher over a Reddit corpus and report throughput.

    patterns_loc: path to the JSONL gazetteer of patterns.
    text_loc: path to the .bz2 Reddit corpus file.
    n: maximum number of texts to read from the corpus.
    lang: language code passed to spacy.blank().
    """
    # Use the requested language; this was previously hard-coded to 'en',
    # which silently ignored the -l option.
    nlp = spacy.blank(lang)
    # Presumably a speed hack so new lexemes skip attribute computation
    # -- TODO confirm it is still needed.
    nlp.vocab.lex_attr_getters = {}
    # Both read_gazetteer and get_matches are lazy generators, so the
    # gazetteer is read and tokenized inside the timed region below.
    phrases = read_gazetteer(nlp.tokenizer, patterns_loc)
    n_docs = [0]  # mutable cell: Python 2 (see __future__ imports) has no nonlocal

    def count_docs(texts):
        # Count the texts actually consumed; the corpus may hold fewer than n.
        for text in texts:
            n_docs[0] += 1
            yield text

    t1 = time.time()
    matches = get_matches(nlp.tokenizer, phrases,
                          count_docs(read_text(text_loc, n=n)))
    count = sum(1 for _ in matches)
    t2 = time.time()
    print("%d docs in %.3f s. %d matches" % (n_docs[0], (t2 - t1), count))


def read_gazetteer(tokenizer, loc, n=-1):
    """Yield tokenized multi-word patterns from a JSONL gazetteer file.

    tokenizer: callable turning a pattern string into a Doc; its ``vocab``
        is also indexed with each token's text.
    loc: path to a UTF-8 JSONL file with one {"text": ...} object per line.
    n: maximum number of lines to read; a negative value (the default)
        means no limit.

    Only patterns that tokenize to two or more tokens are yielded.
    Single-token entries and blank lines are skipped.
    """
    # io.open yields decoded text on both Python 2 and 3. The with-block
    # closes the file when the generator finishes or is closed.
    with io.open(loc, encoding='utf8') as file_:
        for i, line in enumerate(file_):
            if 0 <= n <= i:
                break
            line = line.strip()
            if not line:
                # A blank (e.g. trailing) line is not valid JSON; skip it.
                continue
            phrase = tokenizer(ujson.loads(line)['text'])
            # Index the vocab with every token's text (result discarded);
            # presumably this ensures each word has a lexeme -- TODO confirm.
            for w in phrase:
                _ = tokenizer.vocab[w.text]
            if len(phrase) >= 2:
                yield phrase


def read_text(bz2_loc, n=10000):
    """Yield the 'body' field of up to ``n`` JSON records in a .bz2 file.

    bz2_loc: path to a bz2-compressed file with one JSON object per line
        (e.g. a Reddit comments dump); each object must have a 'body' key.
    n: maximum number of texts to yield; a negative value means no limit.
    """
    with BZ2File(bz2_loc) as file_:
        for i, line in enumerate(file_):
            # Check the limit *before* yielding. The old post-yield check
            # produced n + 1 texts (and one text even for n=0).
            if 0 <= n <= i:
                break
            yield ujson.loads(line)['body']


def get_matches(tokenizer, phrases, texts, max_length=6):
    """Yield a (match id, matched text) pair for every phrase match.

    tokenizer: callable turning a string into a Doc; its ``vocab`` is
        shared with the PhraseMatcher.
    phrases: iterable of pattern Docs, all registered under the key 'Phrase'.
    texts: iterable of raw strings to tokenize and search.
    max_length: forwarded to PhraseMatcher as its ``max_length`` argument.
    """
    vocab = tokenizer.vocab
    phrase_matcher = PhraseMatcher(vocab, max_length=max_length)
    phrase_matcher.add('Phrase', None, *phrases)
    for raw_text in texts:
        doc = tokenizer(raw_text)
        # Index the doc's vocab with each token's text (result discarded);
        # presumably this ensures each word has a lexeme -- TODO confirm.
        for token in doc:
            doc.vocab[token.text]
        for match_id, start, end in phrase_matcher(doc):
            span_text = doc[start:end].text
            yield (match_id, span_text)


if __name__ == '__main__':
    # Set to True to run main() under cProfile. The raw stats are written
    # to Profile.prof and a summary sorted by internal time is printed.
    run_profiler = False
    if not run_profiler:
        plac.call(main)
    else:
        import cProfile
        import pstats
        cProfile.runctx("plac.call(main)", globals(), locals(), "Profile.prof")
        stats = pstats.Stats("Profile.prof")
        stats.strip_dirs().sort_stats("time").print_stats()