mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-10 19:57:17 +03:00
Merge pull request #1343 from explosion/feature/phrasematcher
Update PhraseMatcher for spaCy 2
This commit is contained in:
commit
1ef4236f8e
|
@ -20,72 +20,72 @@ The algorithm is O(n) at run-time for document of length n because we're only ev
|
|||
matching over the tag patterns. So no matter how many phrases we're looking for,
|
||||
our pattern set stays very small (exact size depends on the maximum length we're
|
||||
looking for, as the query language currently has no quantifiers)
|
||||
|
||||
The example expects a .bz2 file from the Reddit corpus, and a patterns file,
|
||||
formatted in jsonl as a sequence of entries like this:
|
||||
|
||||
{"text":"Anchorage"}
|
||||
{"text":"Angola"}
|
||||
{"text":"Ann Arbor"}
|
||||
{"text":"Annapolis"}
|
||||
{"text":"Appalachia"}
|
||||
{"text":"Argentina"}
|
||||
"""
|
||||
from __future__ import print_function, unicode_literals, division
|
||||
from ast import literal_eval
|
||||
from bz2 import BZ2File
|
||||
import time
|
||||
import math
|
||||
import codecs
|
||||
|
||||
import plac
|
||||
import ujson
|
||||
|
||||
from preshed.maps import PreshMap
|
||||
from preshed.counter import PreshCounter
|
||||
from spacy.strings import hash_string
|
||||
from spacy.en import English
|
||||
from spacy.matcher import PhraseMatcher
|
||||
import spacy
|
||||
|
||||
|
||||
def read_gazetteer(tokenizer, loc, n=-1):
|
||||
for i, line in enumerate(open(loc)):
|
||||
phrase = literal_eval('u' + line.strip())
|
||||
if ' (' in phrase and phrase.endswith(')'):
|
||||
phrase = phrase.split(' (', 1)[0]
|
||||
if i >= n:
|
||||
break
|
||||
phrase = tokenizer(phrase)
|
||||
if all((t.is_lower and t.prob >= -10) for t in phrase):
|
||||
continue
|
||||
data = ujson.loads(line.strip())
|
||||
phrase = tokenizer(data['text'])
|
||||
for w in phrase:
|
||||
_ = tokenizer.vocab[w.text]
|
||||
if len(phrase) >= 2:
|
||||
yield phrase
|
||||
|
||||
|
||||
def read_text(bz2_loc):
|
||||
def read_text(bz2_loc, n=10000):
|
||||
with BZ2File(bz2_loc) as file_:
|
||||
for line in file_:
|
||||
yield line.decode('utf8')
|
||||
for i, line in enumerate(file_):
|
||||
data = ujson.loads(line)
|
||||
yield data['body']
|
||||
if i >= n:
|
||||
break
|
||||
|
||||
|
||||
def get_matches(tokenizer, phrases, texts, max_length=6):
|
||||
matcher = PhraseMatcher(tokenizer.vocab, phrases, max_length=max_length)
|
||||
print("Match")
|
||||
matcher = PhraseMatcher(tokenizer.vocab, max_length=max_length)
|
||||
matcher.add('Phrase', None, *phrases)
|
||||
for text in texts:
|
||||
doc = tokenizer(text)
|
||||
for w in doc:
|
||||
_ = doc.vocab[w.text]
|
||||
matches = matcher(doc)
|
||||
for mwe in doc.ents:
|
||||
yield mwe
|
||||
for ent_id, start, end in matches:
|
||||
yield (ent_id, doc[start:end].text)
|
||||
|
||||
|
||||
def main(patterns_loc, text_loc, counts_loc, n=10000000):
|
||||
nlp = English(parser=False, tagger=False, entity=False)
|
||||
print("Make matcher")
|
||||
phrases = read_gazetteer(nlp.tokenizer, patterns_loc, n=n)
|
||||
counts = PreshCounter()
|
||||
def main(patterns_loc, text_loc, n=10000):
|
||||
nlp = spacy.blank('en')
|
||||
nlp.vocab.lex_attr_getters = {}
|
||||
phrases = read_gazetteer(nlp.tokenizer, patterns_loc)
|
||||
count = 0
|
||||
t1 = time.time()
|
||||
for mwe in get_matches(nlp.tokenizer, phrases, read_text(text_loc)):
|
||||
counts.inc(hash_string(mwe.text), 1)
|
||||
for ent_id, text in get_matches(nlp.tokenizer, phrases, read_text(text_loc, n=n)):
|
||||
count += 1
|
||||
t2 = time.time()
|
||||
print("10m tokens in %d s" % (t2 - t1))
|
||||
|
||||
with codecs.open(counts_loc, 'w', 'utf8') as file_:
|
||||
for phrase in read_gazetteer(nlp.tokenizer, patterns_loc, n=n):
|
||||
text = phrase.string
|
||||
key = hash_string(text)
|
||||
count = counts[key]
|
||||
if count != 0:
|
||||
file_.write('%d\t%s\n' % (count, text))
|
||||
|
||||
print("%d docs in %.3f s. %d matches" % (n, (t2 - t1), count))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if False:
|
|
@ -421,47 +421,69 @@ cdef class PhraseMatcher:
|
|||
cdef int max_length
|
||||
cdef attr_t* _phrase_key
|
||||
|
||||
def __init__(self, Vocab vocab, phrases, max_length=10):
|
||||
cdef public object _callbacks
|
||||
cdef public object _patterns
|
||||
|
||||
def __init__(self, Vocab vocab, max_length=10):
|
||||
self.mem = Pool()
|
||||
self._phrase_key = <attr_t*>self.mem.alloc(max_length, sizeof(attr_t))
|
||||
self.max_length = max_length
|
||||
self.vocab = vocab
|
||||
self.matcher = Matcher(self.vocab, {})
|
||||
self.matcher = Matcher(self.vocab)
|
||||
self.phrase_ids = PreshMap()
|
||||
for phrase in phrases:
|
||||
if len(phrase) < max_length:
|
||||
self.add(phrase)
|
||||
|
||||
abstract_patterns = []
|
||||
for length in range(1, max_length):
|
||||
abstract_patterns.append([{tag: True} for tag in get_bilou(length)])
|
||||
self.matcher.add('Candidate', 'MWE', {}, abstract_patterns, acceptor=self.accept_match)
|
||||
self.matcher.add('Candidate', None, *abstract_patterns)
|
||||
self._callbacks = {}
|
||||
|
||||
def add(self, Doc tokens):
|
||||
cdef int length = tokens.length
|
||||
assert length < self.max_length
|
||||
tags = get_bilou(length)
|
||||
assert len(tags) == length, length
|
||||
def __len__(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def __contains__(self, key):
|
||||
raise NotImplementedError
|
||||
|
||||
def __reduce__(self):
|
||||
return (self.__class__, (self.vocab,), None, None)
|
||||
|
||||
def add(self, key, on_match, *docs):
|
||||
cdef Doc doc
|
||||
for doc in docs:
|
||||
if len(doc) >= self.max_length:
|
||||
msg = (
|
||||
"Pattern length (%d) >= phrase_matcher.max_length (%d). "
|
||||
"Length can be set on initialization, up to 10."
|
||||
)
|
||||
raise ValueError(msg % (len(doc), self.max_length))
|
||||
cdef hash_t ent_id = self.matcher._normalize_key(key)
|
||||
self._callbacks[ent_id] = on_match
|
||||
|
||||
cdef int length
|
||||
cdef int i
|
||||
for i in range(self.max_length):
|
||||
self._phrase_key[i] = 0
|
||||
for i, tag in enumerate(tags):
|
||||
lexeme = self.vocab[tokens.c[i].lex.orth]
|
||||
lexeme.set_flag(tag, True)
|
||||
self._phrase_key[i] = lexeme.orth
|
||||
cdef hash_t key = hash64(self._phrase_key, self.max_length * sizeof(attr_t), 0)
|
||||
self.phrase_ids[key] = True
|
||||
cdef hash_t phrase_hash
|
||||
for doc in docs:
|
||||
length = doc.length
|
||||
tags = get_bilou(length)
|
||||
for i in range(self.max_length):
|
||||
self._phrase_key[i] = 0
|
||||
for i, tag in enumerate(tags):
|
||||
lexeme = self.vocab[doc.c[i].lex.orth]
|
||||
lexeme.set_flag(tag, True)
|
||||
self._phrase_key[i] = lexeme.orth
|
||||
phrase_hash = hash64(self._phrase_key,
|
||||
self.max_length * sizeof(attr_t), 0)
|
||||
self.phrase_ids.set(phrase_hash, <void*>ent_id)
|
||||
|
||||
def __call__(self, Doc doc):
|
||||
matches = []
|
||||
for ent_id, label, start, end in self.matcher(doc):
|
||||
cand = doc[start : end]
|
||||
start = cand[0].idx
|
||||
end = cand[-1].idx + len(cand[-1])
|
||||
matches.append((start, end, cand.root.tag_, cand.text, 'MWE'))
|
||||
for match in matches:
|
||||
doc.merge(*match)
|
||||
for _, start, end in self.matcher(doc):
|
||||
ent_id = self.accept_match(doc, start, end)
|
||||
if ent_id is not None:
|
||||
matches.append((ent_id, start, end))
|
||||
for i, (ent_id, start, end) in enumerate(matches):
|
||||
on_match = self._callbacks.get(ent_id)
|
||||
if on_match is not None:
|
||||
on_match(self, doc, i, matches)
|
||||
return matches
|
||||
|
||||
def pipe(self, stream, batch_size=1000, n_threads=2):
|
||||
|
@ -469,7 +491,7 @@ cdef class PhraseMatcher:
|
|||
self(doc)
|
||||
yield doc
|
||||
|
||||
def accept_match(self, Doc doc, attr_t ent_id, attr_t label, int start, int end):
|
||||
def accept_match(self, Doc doc, int start, int end):
|
||||
assert (end - start) < self.max_length
|
||||
cdef int i, j
|
||||
for i in range(self.max_length):
|
||||
|
@ -477,7 +499,8 @@ cdef class PhraseMatcher:
|
|||
for i, j in enumerate(range(start, end)):
|
||||
self._phrase_key[i] = doc.c[j].lex.orth
|
||||
cdef hash_t key = hash64(self._phrase_key, self.max_length * sizeof(attr_t), 0)
|
||||
if self.phrase_ids.get(key):
|
||||
return (ent_id, label, start, end)
|
||||
ent_id = <hash_t>self.phrase_ids.get(key)
|
||||
if ent_id == 0:
|
||||
return None
|
||||
else:
|
||||
return False
|
||||
return ent_id
|
||||
|
|
|
@ -34,7 +34,6 @@ def test_matcher_from_api_docs(en_vocab):
|
|||
assert len(patterns[0])
|
||||
|
||||
|
||||
@pytest.mark.xfail
|
||||
def test_matcher_from_usage_docs(en_vocab):
|
||||
text = "Wow 😀 This is really cool! 😂 😂"
|
||||
doc = get_doc(en_vocab, words=text.split(' '))
|
||||
|
@ -46,7 +45,8 @@ def test_matcher_from_usage_docs(en_vocab):
|
|||
if doc.vocab.strings[match_id] == 'HAPPY':
|
||||
doc.sentiment += 0.1
|
||||
span = doc[start : end]
|
||||
token = span.merge(norm='happy emoji')
|
||||
token = span.merge()
|
||||
token.vocab[token.text].norm_ = 'happy emoji'
|
||||
|
||||
matcher = Matcher(en_vocab)
|
||||
matcher.add('HAPPY', label_sentiment, *pos_patterns)
|
||||
|
@ -98,11 +98,11 @@ def test_matcher_match_multi(matcher):
|
|||
(doc.vocab.strings['Java'], 5, 6)]
|
||||
|
||||
|
||||
@pytest.mark.xfail
|
||||
def test_matcher_phrase_matcher(en_vocab):
|
||||
words = ["Google", "Now"]
|
||||
doc = get_doc(en_vocab, words)
|
||||
matcher = PhraseMatcher(en_vocab, [doc])
|
||||
matcher = PhraseMatcher(en_vocab)
|
||||
matcher.add('COMPANY', None, doc)
|
||||
words = ["I", "like", "Google", "Now", "best"]
|
||||
doc = get_doc(en_vocab, words)
|
||||
assert len(matcher(doc)) == 1
|
||||
|
|
Loading…
Reference in New Issue
Block a user