diff --git a/examples/multi_word_matches.py b/examples/phrase_matcher.py
similarity index 59%
rename from examples/multi_word_matches.py
rename to examples/phrase_matcher.py
index 73f48bf42..ca9b0cc92 100644
--- a/examples/multi_word_matches.py
+++ b/examples/phrase_matcher.py
@@ -20,72 +20,72 @@ The algorithm is O(n) at run-time for document of length n because we're only ev
 matching over the tag patterns. So no matter how many phrases we're looking for,
 our pattern set stays very small (exact size depends on the maximum length we're
 looking for, as the query language currently has no quantifiers)
+
+The example expects a .bz2 file from the Reddit corpus, and a patterns file,
+formatted in jsonl as a sequence of entries like this:
+
+{"text":"Anchorage"}
+{"text":"Angola"}
+{"text":"Ann Arbor"}
+{"text":"Annapolis"}
+{"text":"Appalachia"}
+{"text":"Argentina"}
 """
 from __future__ import print_function, unicode_literals, division
-from ast import literal_eval
 from bz2 import BZ2File
 import time
 import math
 import codecs
 import plac
+import ujson

-from preshed.maps import PreshMap
-from preshed.counter import PreshCounter
-from spacy.strings import hash_string
-from spacy.en import English
 from spacy.matcher import PhraseMatcher
+import spacy


 def read_gazetteer(tokenizer, loc, n=-1):
     for i, line in enumerate(open(loc)):
-        phrase = literal_eval('u' + line.strip())
-        if ' (' in phrase and phrase.endswith(')'):
-            phrase = phrase.split(' (', 1)[0]
-        if i >= n:
-            break
-        phrase = tokenizer(phrase)
-        if all((t.is_lower and t.prob >= -10) for t in phrase):
-            continue
+        data = ujson.loads(line.strip())
+        phrase = tokenizer(data['text'])
+        for w in phrase:
+            _ = tokenizer.vocab[w.text]
         if len(phrase) >= 2:
             yield phrase


-def read_text(bz2_loc):
+def read_text(bz2_loc, n=10000):
     with BZ2File(bz2_loc) as file_:
-        for line in file_:
-            yield line.decode('utf8')
+        for i, line in enumerate(file_):
+            data = ujson.loads(line)
+            yield data['body']
+            if i >= n:
+                break


 def get_matches(tokenizer, phrases, texts, max_length=6):
-    matcher = PhraseMatcher(tokenizer.vocab, phrases, max_length=max_length)
-    print("Match")
+    matcher = PhraseMatcher(tokenizer.vocab, max_length=max_length)
+    matcher.add('Phrase', None, *phrases)
     for text in texts:
         doc = tokenizer(text)
+        for w in doc:
+            _ = doc.vocab[w.text]
         matches = matcher(doc)
-        for mwe in doc.ents:
-            yield mwe
+        for ent_id, start, end in matches:
+            yield (ent_id, doc[start:end].text)


-def main(patterns_loc, text_loc, counts_loc, n=10000000):
-    nlp = English(parser=False, tagger=False, entity=False)
-    print("Make matcher")
-    phrases = read_gazetteer(nlp.tokenizer, patterns_loc, n=n)
-    counts = PreshCounter()
+def main(patterns_loc, text_loc, n=10000):
+    nlp = spacy.blank('en')
+    nlp.vocab.lex_attr_getters = {}
+    phrases = read_gazetteer(nlp.tokenizer, patterns_loc)
+    count = 0
     t1 = time.time()
-    for mwe in get_matches(nlp.tokenizer, phrases, read_text(text_loc)):
-        counts.inc(hash_string(mwe.text), 1)
+    for ent_id, text in get_matches(nlp.tokenizer, phrases, read_text(text_loc, n=n)):
+        count += 1
     t2 = time.time()
-    print("10m tokens in %d s" % (t2 - t1))
-
-    with codecs.open(counts_loc, 'w', 'utf8') as file_:
-        for phrase in read_gazetteer(nlp.tokenizer, patterns_loc, n=n):
-            text = phrase.string
-            key = hash_string(text)
-            count = counts[key]
-            if count != 0:
-                file_.write('%d\t%s\n' % (count, text))
-
+    print("%d docs in %.3f s. %d matches" % (n, (t2 - t1), count))
+

 if __name__ == '__main__':
     if False:
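Taken together, the changes to the example reduce to the following usage pattern for the new PhraseMatcher API. This is a minimal sketch against the API introduced in this diff; the 'CITY' key, the phrases and the sample sentence are illustrative placeholders rather than values from the patch.

    # Minimal sketch of the new usage shown in examples/phrase_matcher.py.
    # Assumes the PhraseMatcher API from this diff; names are illustrative.
    import spacy
    from spacy.matcher import PhraseMatcher

    nlp = spacy.blank('en')
    matcher = PhraseMatcher(nlp.vocab)
    # Patterns are Doc objects; add() now takes a key, an optional on_match
    # callback (None here) and any number of pattern docs.
    patterns = [nlp.tokenizer(text) for text in ('Ann Arbor', 'Anchorage')]
    matcher.add('CITY', None, *patterns)

    doc = nlp.tokenizer('I drove to Ann Arbor yesterday')
    for ent_id, start, end in matcher(doc):
        print(nlp.vocab.strings[ent_id], doc[start:end].text)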
diff --git a/spacy/matcher.pyx b/spacy/matcher.pyx
index c75d23957..3bc6f859c 100644
--- a/spacy/matcher.pyx
+++ b/spacy/matcher.pyx
@@ -421,47 +421,69 @@ cdef class PhraseMatcher:
     cdef int max_length
     cdef attr_t* _phrase_key

-    def __init__(self, Vocab vocab, phrases, max_length=10):
+    cdef public object _callbacks
+    cdef public object _patterns
+
+    def __init__(self, Vocab vocab, max_length=10):
         self.mem = Pool()
         self._phrase_key = <attr_t*>self.mem.alloc(max_length, sizeof(attr_t))
         self.max_length = max_length
         self.vocab = vocab
-        self.matcher = Matcher(self.vocab, {})
+        self.matcher = Matcher(self.vocab)
         self.phrase_ids = PreshMap()
-        for phrase in phrases:
-            if len(phrase) < max_length:
-                self.add(phrase)
-
         abstract_patterns = []
         for length in range(1, max_length):
             abstract_patterns.append([{tag: True} for tag in get_bilou(length)])
-        self.matcher.add('Candidate', 'MWE', {}, abstract_patterns, acceptor=self.accept_match)
+        self.matcher.add('Candidate', None, *abstract_patterns)
+        self._callbacks = {}

-    def add(self, Doc tokens):
-        cdef int length = tokens.length
-        assert length < self.max_length
-        tags = get_bilou(length)
-        assert len(tags) == length, length
+    def __len__(self):
+        raise NotImplementedError
+
+    def __contains__(self, key):
+        raise NotImplementedError
+
+    def __reduce__(self):
+        return (self.__class__, (self.vocab,), None, None)
+
+    def add(self, key, on_match, *docs):
+        cdef Doc doc
+        for doc in docs:
+            if len(doc) >= self.max_length:
+                msg = (
+                    "Pattern length (%d) >= phrase_matcher.max_length (%d). "
+                    "Length can be set on initialization, up to 10."
+                )
+                raise ValueError(msg % (len(doc), self.max_length))
+        cdef hash_t ent_id = self.matcher._normalize_key(key)
+        self._callbacks[ent_id] = on_match
+
+        cdef int length
         cdef int i
-        for i in range(self.max_length):
-            self._phrase_key[i] = 0
-        for i, tag in enumerate(tags):
-            lexeme = self.vocab[tokens.c[i].lex.orth]
-            lexeme.set_flag(tag, True)
-            self._phrase_key[i] = lexeme.orth
-        cdef hash_t key = hash64(self._phrase_key, self.max_length * sizeof(attr_t), 0)
-        self.phrase_ids[key] = True
+        cdef hash_t phrase_hash
+        for doc in docs:
+            length = doc.length
+            tags = get_bilou(length)
+            for i in range(self.max_length):
+                self._phrase_key[i] = 0
+            for i, tag in enumerate(tags):
+                lexeme = self.vocab[doc.c[i].lex.orth]
+                lexeme.set_flag(tag, True)
+                self._phrase_key[i] = lexeme.orth
+            phrase_hash = hash64(self._phrase_key,
+                                 self.max_length * sizeof(attr_t), 0)
+            self.phrase_ids.set(phrase_hash, <void*>ent_id)

     def __call__(self, Doc doc):
         matches = []
-        for ent_id, label, start, end in self.matcher(doc):
-            cand = doc[start : end]
-            start = cand[0].idx
-            end = cand[-1].idx + len(cand[-1])
-            matches.append((start, end, cand.root.tag_, cand.text, 'MWE'))
-        for match in matches:
-            doc.merge(*match)
+        for _, start, end in self.matcher(doc):
+            ent_id = self.accept_match(doc, start, end)
+            if ent_id is not None:
+                matches.append((ent_id, start, end))
+        for i, (ent_id, start, end) in enumerate(matches):
+            on_match = self._callbacks.get(ent_id)
+            if on_match is not None:
+                on_match(self, doc, i, matches)
         return matches

     def pipe(self, stream, batch_size=1000, n_threads=2):
@@ -469,7 +491,7 @@ cdef class PhraseMatcher:
             self(doc)
             yield doc

-    def accept_match(self, Doc doc, attr_t ent_id, attr_t label, int start, int end):
+    def accept_match(self, Doc doc, int start, int end):
         assert (end - start) < self.max_length
         cdef int i, j
         for i in range(self.max_length):
@@ -477,7 +499,8 @@ cdef class PhraseMatcher:
         for i, j in enumerate(range(start, end)):
             self._phrase_key[i] = doc.c[j].lex.orth
         cdef hash_t key = hash64(self._phrase_key, self.max_length * sizeof(attr_t), 0)
-        if self.phrase_ids.get(key):
-            return (ent_id, label, start, end)
+        ent_id = <hash_t>self.phrase_ids.get(key)
+        if ent_id == 0:
+            return None
         else:
-            return False
+            return ent_id
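The rewritten __call__ above also defines the callback contract: every on_match passed to add() is invoked as on_match(matcher, doc, i, matches) once all matches have been collected, with matches being the full list of (ent_id, start, end) tuples. Below is a small sketch of that contract; 'PRODUCT', collect and hits are illustrative names, not part of the patch.

    # Sketch of the on_match callback signature implied by __call__ above.
    import spacy
    from spacy.matcher import PhraseMatcher

    nlp = spacy.blank('en')
    hits = []

    def collect(matcher, doc, i, matches):
        # matches[i] is the match this callback was fired for
        ent_id, start, end = matches[i]
        hits.append((nlp.vocab.strings[ent_id], doc[start:end].text))

    matcher = PhraseMatcher(nlp.vocab)
    matcher.add('PRODUCT', collect, nlp.tokenizer('Google Now'))
    matcher(nlp.tokenizer('I like Google Now best'))
    # hits should now be [('PRODUCT', 'Google Now')]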
diff --git a/spacy/tests/test_matcher.py b/spacy/tests/test_matcher.py
index 388aab03e..1b9f92519 100644
--- a/spacy/tests/test_matcher.py
+++ b/spacy/tests/test_matcher.py
@@ -34,7 +34,6 @@ def test_matcher_from_api_docs(en_vocab):
     assert len(patterns[0])


-@pytest.mark.xfail
 def test_matcher_from_usage_docs(en_vocab):
     text = "Wow 😀 This is really cool! 😂 😂"
     doc = get_doc(en_vocab, words=text.split(' '))
@@ -46,7 +45,8 @@ def test_matcher_from_usage_docs(en_vocab):
         if doc.vocab.strings[match_id] == 'HAPPY':
             doc.sentiment += 0.1
         span = doc[start : end]
-        token = span.merge(norm='happy emoji')
+        token = span.merge()
+        token.vocab[token.text].norm_ = 'happy emoji'

     matcher = Matcher(en_vocab)
     matcher.add('HAPPY', label_sentiment, *pos_patterns)
@@ -98,11 +98,11 @@ def test_matcher_match_multi(matcher):
             (doc.vocab.strings['Java'], 5, 6)]


-@pytest.mark.xfail
 def test_matcher_phrase_matcher(en_vocab):
     words = ["Google", "Now"]
     doc = get_doc(en_vocab, words)
-    matcher = PhraseMatcher(en_vocab, [doc])
+    matcher = PhraseMatcher(en_vocab)
+    matcher.add('COMPANY', None, doc)
     words = ["I", "like", "Google", "Now", "best"]
     doc = get_doc(en_vocab, words)
     assert len(matcher(doc)) == 1
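One gap the updated tests leave open is the new length guard in PhraseMatcher.add(). A possible companion test, not part of this patch, could look like the sketch below; it assumes the same en_vocab fixture, get_doc helper and pytest import used by the existing module, and the test name is hypothetical.

    # Hypothetical companion test for the max_length guard in add().
    def test_matcher_phrase_matcher_max_length(en_vocab):
        matcher = PhraseMatcher(en_vocab, max_length=3)
        doc = get_doc(en_vocab, ["one", "two", "three", "four"])  # len 4 >= 3
        with pytest.raises(ValueError):
            matcher.add('TOO_LONG', None, doc)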