Merge pull request #1343 from explosion/feature/phrasematcher

Update PhraseMatcher for spaCy 2
Matthew Honnibal 2017-09-26 20:44:23 +02:00 committed by GitHub
commit 1ef4236f8e
3 changed files with 95 additions and 72 deletions
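In spaCy 2, phrases are registered on the matcher under a string key with an optional on_match callback, and calling the matcher returns (match_id, start, end) tuples instead of merging matched spans into the Doc. A minimal usage sketch consistent with the diff below; the blank pipeline and example phrases are illustrative, not part of this commit:

    # Sketch of the updated API; the phrases here are placeholders.
    import spacy
    from spacy.matcher import PhraseMatcher

    nlp = spacy.blank('en')
    matcher = PhraseMatcher(nlp.vocab)
    # add() takes a key, an optional on_match callback, and one or more Doc patterns.
    matcher.add('CITIES', None, nlp(u'Ann Arbor'), nlp(u'Anchorage'))

    doc = nlp(u'I used to live in Ann Arbor.')
    for match_id, start, end in matcher(doc):
        print(doc.vocab.strings[match_id], doc[start:end].text)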


@@ -20,72 +20,72 @@ The algorithm is O(n) at run-time for document of length n because we're only ev
 matching over the tag patterns. So no matter how many phrases we're looking for,
 our pattern set stays very small (exact size depends on the maximum length we're
 looking for, as the query language currently has no quantifiers)
+The example expects a .bz2 file from the Reddit corpus, and a patterns file,
+formatted in jsonl as a sequence of entries like this:
+{"text":"Anchorage"}
+{"text":"Angola"}
+{"text":"Ann Arbor"}
+{"text":"Annapolis"}
+{"text":"Appalachia"}
+{"text":"Argentina"}
 """
 from __future__ import print_function, unicode_literals, division
-from ast import literal_eval
 from bz2 import BZ2File
 import time
 import math
 import codecs
 import plac
+import ujson
-from preshed.maps import PreshMap
-from preshed.counter import PreshCounter
-from spacy.strings import hash_string
-from spacy.en import English
 from spacy.matcher import PhraseMatcher
+import spacy
 
 def read_gazetteer(tokenizer, loc, n=-1):
     for i, line in enumerate(open(loc)):
-        phrase = literal_eval('u' + line.strip())
-        if ' (' in phrase and phrase.endswith(')'):
-            phrase = phrase.split(' (', 1)[0]
-        if i >= n:
-            break
-        phrase = tokenizer(phrase)
-        if all((t.is_lower and t.prob >= -10) for t in phrase):
-            continue
+        data = ujson.loads(line.strip())
+        phrase = tokenizer(data['text'])
+        for w in phrase:
+            _ = tokenizer.vocab[w.text]
         if len(phrase) >= 2:
             yield phrase
 
-def read_text(bz2_loc):
+def read_text(bz2_loc, n=10000):
     with BZ2File(bz2_loc) as file_:
-        for line in file_:
-            yield line.decode('utf8')
+        for i, line in enumerate(file_):
+            data = ujson.loads(line)
+            yield data['body']
+            if i >= n:
+                break
 
 def get_matches(tokenizer, phrases, texts, max_length=6):
-    matcher = PhraseMatcher(tokenizer.vocab, phrases, max_length=max_length)
-    print("Match")
+    matcher = PhraseMatcher(tokenizer.vocab, max_length=max_length)
+    matcher.add('Phrase', None, *phrases)
     for text in texts:
         doc = tokenizer(text)
+        for w in doc:
+            _ = doc.vocab[w.text]
         matches = matcher(doc)
-        for mwe in doc.ents:
-            yield mwe
+        for ent_id, start, end in matches:
+            yield (ent_id, doc[start:end].text)
 
-def main(patterns_loc, text_loc, counts_loc, n=10000000):
-    nlp = English(parser=False, tagger=False, entity=False)
-    print("Make matcher")
-    phrases = read_gazetteer(nlp.tokenizer, patterns_loc, n=n)
-    counts = PreshCounter()
+def main(patterns_loc, text_loc, n=10000):
+    nlp = spacy.blank('en')
+    nlp.vocab.lex_attr_getters = {}
+    phrases = read_gazetteer(nlp.tokenizer, patterns_loc)
+    count = 0
     t1 = time.time()
-    for mwe in get_matches(nlp.tokenizer, phrases, read_text(text_loc)):
-        counts.inc(hash_string(mwe.text), 1)
+    for ent_id, text in get_matches(nlp.tokenizer, phrases, read_text(text_loc, n=n)):
+        count += 1
     t2 = time.time()
-    print("10m tokens in %d s" % (t2 - t1))
-    with codecs.open(counts_loc, 'w', 'utf8') as file_:
-        for phrase in read_gazetteer(nlp.tokenizer, patterns_loc, n=n):
-            text = phrase.string
-            key = hash_string(text)
-            count = counts[key]
-            if count != 0:
-                file_.write('%d\t%s\n' % (count, text))
+    print("%d docs in %.3f s. %d matches" % (n, (t2 - t1), count))
 
 if __name__ == '__main__':
     if False:
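The example's docstring above spells out the inputs: a Reddit comments .bz2 dump and a jsonl gazetteer of {"text": ...} entries. A small sketch of writing such a gazetteer and running it through the read_gazetteer function from the example; the file name and place names are illustrative:

    # Build a tiny patterns file in the jsonl format the example expects.
    # 'cities.jsonl' and the entries are placeholders.
    import ujson
    import spacy

    with open('cities.jsonl', 'w') as f:
        for name in ['Ann Arbor', 'Anchorage', 'Argentina']:
            f.write(ujson.dumps({'text': name}) + '\n')

    nlp = spacy.blank('en')
    # read_gazetteer is defined in the example above; it tokenizes each entry
    # and only yields phrases of two or more tokens.
    phrases = list(read_gazetteer(nlp.tokenizer, 'cities.jsonl'))
    print([p.text for p in phrases])   # ['Ann Arbor']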


@@ -421,47 +421,69 @@ cdef class PhraseMatcher:
     cdef int max_length
     cdef attr_t* _phrase_key
+    cdef public object _callbacks
+    cdef public object _patterns
 
-    def __init__(self, Vocab vocab, phrases, max_length=10):
+    def __init__(self, Vocab vocab, max_length=10):
         self.mem = Pool()
         self._phrase_key = <attr_t*>self.mem.alloc(max_length, sizeof(attr_t))
         self.max_length = max_length
         self.vocab = vocab
-        self.matcher = Matcher(self.vocab, {})
+        self.matcher = Matcher(self.vocab)
         self.phrase_ids = PreshMap()
-        for phrase in phrases:
-            if len(phrase) < max_length:
-                self.add(phrase)
         abstract_patterns = []
         for length in range(1, max_length):
             abstract_patterns.append([{tag: True} for tag in get_bilou(length)])
-        self.matcher.add('Candidate', 'MWE', {}, abstract_patterns, acceptor=self.accept_match)
+        self.matcher.add('Candidate', None, *abstract_patterns)
+        self._callbacks = {}
 
-    def add(self, Doc tokens):
-        cdef int length = tokens.length
-        assert length < self.max_length
-        tags = get_bilou(length)
-        assert len(tags) == length, length
+    def __len__(self):
+        raise NotImplementedError
+
+    def __contains__(self, key):
+        raise NotImplementedError
+
+    def __reduce__(self):
+        return (self.__class__, (self.vocab,), None, None)
+
+    def add(self, key, on_match, *docs):
+        cdef Doc doc
+        for doc in docs:
+            if len(doc) >= self.max_length:
+                msg = (
+                    "Pattern length (%d) >= phrase_matcher.max_length (%d). "
+                    "Length can be set on initialization, up to 10."
+                )
+                raise ValueError(msg % (len(doc), self.max_length))
+        cdef hash_t ent_id = self.matcher._normalize_key(key)
+        self._callbacks[ent_id] = on_match
+        cdef int length
         cdef int i
-        for i in range(self.max_length):
-            self._phrase_key[i] = 0
-        for i, tag in enumerate(tags):
-            lexeme = self.vocab[tokens.c[i].lex.orth]
-            lexeme.set_flag(tag, True)
-            self._phrase_key[i] = lexeme.orth
-        cdef hash_t key = hash64(self._phrase_key, self.max_length * sizeof(attr_t), 0)
-        self.phrase_ids[key] = True
+        cdef hash_t phrase_hash
+        for doc in docs:
+            length = doc.length
+            tags = get_bilou(length)
+            for i in range(self.max_length):
+                self._phrase_key[i] = 0
+            for i, tag in enumerate(tags):
+                lexeme = self.vocab[doc.c[i].lex.orth]
+                lexeme.set_flag(tag, True)
+                self._phrase_key[i] = lexeme.orth
+            phrase_hash = hash64(self._phrase_key,
+                                 self.max_length * sizeof(attr_t), 0)
+            self.phrase_ids.set(phrase_hash, <void*>ent_id)
 
     def __call__(self, Doc doc):
         matches = []
-        for ent_id, label, start, end in self.matcher(doc):
-            cand = doc[start : end]
-            start = cand[0].idx
-            end = cand[-1].idx + len(cand[-1])
-            matches.append((start, end, cand.root.tag_, cand.text, 'MWE'))
-        for match in matches:
-            doc.merge(*match)
+        for _, start, end in self.matcher(doc):
+            ent_id = self.accept_match(doc, start, end)
+            if ent_id is not None:
+                matches.append((ent_id, start, end))
+        for i, (ent_id, start, end) in enumerate(matches):
+            on_match = self._callbacks.get(ent_id)
+            if on_match is not None:
+                on_match(self, doc, i, matches)
         return matches
 
     def pipe(self, stream, batch_size=1000, n_threads=2):
@@ -469,7 +491,7 @@ cdef class PhraseMatcher:
             self(doc)
             yield doc
 
-    def accept_match(self, Doc doc, attr_t ent_id, attr_t label, int start, int end):
+    def accept_match(self, Doc doc, int start, int end):
         assert (end - start) < self.max_length
         cdef int i, j
         for i in range(self.max_length):
@@ -477,7 +499,8 @@ cdef class PhraseMatcher:
         for i, j in enumerate(range(start, end)):
             self._phrase_key[i] = doc.c[j].lex.orth
         cdef hash_t key = hash64(self._phrase_key, self.max_length * sizeof(attr_t), 0)
-        if self.phrase_ids.get(key):
-            return (ent_id, label, start, end)
+        ent_id = <hash_t>self.phrase_ids.get(key)
+        if ent_id == 0:
+            return None
         else:
-            return False
+            return ent_id
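To make the scheme above easier to follow: add() flags each word of a phrase with a length-specific BILOU tag and stores a fixed-width key for the phrase in a hash table, while __call__ lets the tag patterns propose candidate spans and accept_match confirms each one with a single hash lookup, so run-time stays independent of the number of phrases. A rough pure-Python sketch of the same idea; the class, helper and flag names are illustrative stand-ins, not the Cython internals:

    # Toy re-implementation of the flag-plus-hash-lookup scheme, for illustration only.
    def get_bilou(length):
        # One tag per position, specific to the phrase length (cf. get_bilou in matcher.pyx).
        if length == 1:
            return ['U1_ENT']
        return (['B%d_ENT' % length] +
                ['I%d_ENT' % length] * (length - 2) +
                ['L%d_ENT' % length])

    class ToyPhraseMatcher(object):
        def __init__(self, max_length=10):
            self.max_length = max_length
            self.word_flags = {}   # word -> set of BILOU tags (stands in for lexeme flags)
            self.phrase_ids = {}   # fixed-width phrase key -> match key (stands in for PreshMap)

        def _key(self, words):
            # Zero-padded, fixed-width key, mirroring the _phrase_key buffer.
            return tuple(words) + ('',) * (self.max_length - len(words))

        def add(self, key, *phrases):
            for words in phrases:
                for word, tag in zip(words, get_bilou(len(words))):
                    self.word_flags.setdefault(word, set()).add(tag)
                self.phrase_ids[self._key(words)] = key

        def __call__(self, words):
            matches = []
            # Candidates come from the per-length tag patterns alone, so the pattern
            # set never grows with the number of phrases; each candidate is then
            # confirmed with one hash lookup, like accept_match above.
            for length in range(1, self.max_length):
                tags = get_bilou(length)
                for start in range(len(words) - length + 1):
                    span = words[start:start + length]
                    if all(tag in self.word_flags.get(w, ()) for w, tag in zip(span, tags)):
                        key = self.phrase_ids.get(self._key(span))
                        if key is not None:
                            matches.append((key, start, start + length))
            return matches

    matcher = ToyPhraseMatcher()
    matcher.add('CITIES', ['Ann', 'Arbor'])
    print(matcher(['I', 'like', 'Ann', 'Arbor', 'best']))   # [('CITIES', 2, 4)]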


@@ -34,7 +34,6 @@ def test_matcher_from_api_docs(en_vocab):
     assert len(patterns[0])
 
-@pytest.mark.xfail
 def test_matcher_from_usage_docs(en_vocab):
     text = "Wow 😀 This is really cool! 😂 😂"
     doc = get_doc(en_vocab, words=text.split(' '))
@@ -46,7 +45,8 @@ def test_matcher_from_usage_docs(en_vocab):
         if doc.vocab.strings[match_id] == 'HAPPY':
             doc.sentiment += 0.1
         span = doc[start : end]
-        token = span.merge(norm='happy emoji')
+        token = span.merge()
+        token.vocab[token.text].norm_ = 'happy emoji'
 
     matcher = Matcher(en_vocab)
     matcher.add('HAPPY', label_sentiment, *pos_patterns)
@@ -98,11 +98,11 @@ def test_matcher_match_multi(matcher):
                               (doc.vocab.strings['Java'], 5, 6)]
 
-@pytest.mark.xfail
 def test_matcher_phrase_matcher(en_vocab):
     words = ["Google", "Now"]
     doc = get_doc(en_vocab, words)
-    matcher = PhraseMatcher(en_vocab, [doc])
+    matcher = PhraseMatcher(en_vocab)
+    matcher.add('COMPANY', None, doc)
     words = ["I", "like", "Google", "Now", "best"]
     doc = get_doc(en_vocab, words)
     assert len(matcher(doc)) == 1
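The None argument to add() in the test above is the slot for an on_match callback, which the updated PhraseMatcher invokes as on_match(matcher, doc, i, matches) for each match. A short sketch of supplying one; the key, callback body and phrases are illustrative:

    # Sketch: on_match callback with the spaCy 2 PhraseMatcher.
    import spacy
    from spacy.matcher import PhraseMatcher

    def collect(matcher, doc, i, matches):
        match_id, start, end = matches[i]
        print('matched:', doc[start:end].text)

    nlp = spacy.blank('en')
    matcher = PhraseMatcher(nlp.vocab)
    matcher.add('COMPANY', collect, nlp(u'Google Now'))
    matcher(nlp(u'I like Google Now best'))   # prints: matched: Google Now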