	Merge pull request #1343 from explosion/feature/phrasematcher
Update PhraseMatcher for spaCy 2
commit 1ef4236f8e
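The headline change is the PhraseMatcher API for spaCy 2: phrase patterns are no longer passed to the constructor, but registered with `add(key, on_match, *docs)`, and calling the matcher returns `(match_id, start, end)` tuples like the token `Matcher`. A minimal usage sketch based on the diff below; the pipeline, phrases and sentence here are illustrative, not taken from the repo:

```python
import spacy
from spacy.matcher import PhraseMatcher

# Tokenizer-only pipeline, as in the updated example script.
nlp = spacy.blank('en')
matcher = PhraseMatcher(nlp.vocab, max_length=6)

# Patterns are plain Doc objects; the second argument is an optional on_match callback.
patterns = [nlp(text) for text in ('Ann Arbor', 'Buenos Aires')]
matcher.add('GPE', None, *patterns)

doc = nlp('She moved from Ann Arbor to Buenos Aires last year.')
for match_id, start, end in matcher(doc):
    # match_id is the hashed pattern key; start/end are token offsets into doc.
    print(match_id, doc[start:end].text)
```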
@@ -20,71 +20,71 @@ The algorithm is O(n) at run-time for document of length n because we're only ever
matching over the tag patterns. So no matter how many phrases we're looking for,
our pattern set stays very small (exact size depends on the maximum length we're
looking for, as the query language currently has no quantifiers)

The example expects a .bz2 file from the Reddit corpus, and a patterns file,
formatted in jsonl as a sequence of entries like this:

{"text":"Anchorage"}
{"text":"Angola"}
{"text":"Ann Arbor"}
{"text":"Annapolis"}
{"text":"Appalachia"}
{"text":"Argentina"}
"""
from __future__ import print_function, unicode_literals, division
from ast import literal_eval
from bz2 import BZ2File
import time
import math
import codecs

import plac
import ujson

from preshed.maps import PreshMap
from preshed.counter import PreshCounter
from spacy.strings import hash_string
from spacy.en import English
from spacy.matcher import PhraseMatcher
import spacy


def read_gazetteer(tokenizer, loc, n=-1):
    for i, line in enumerate(open(loc)):
        phrase = literal_eval('u' + line.strip())
        if ' (' in phrase and phrase.endswith(')'):
            phrase = phrase.split(' (', 1)[0]
        if i >= n:
            break
        phrase = tokenizer(phrase)
        if all((t.is_lower and t.prob >= -10) for t in phrase):
            continue
        data = ujson.loads(line.strip())
        phrase = tokenizer(data['text'])
        for w in phrase:
            _ = tokenizer.vocab[w.text]
        if len(phrase) >= 2:
            yield phrase


def read_text(bz2_loc):
def read_text(bz2_loc, n=10000):
    with BZ2File(bz2_loc) as file_:
        for line in file_:
            yield line.decode('utf8')
        for i, line in enumerate(file_):
            data = ujson.loads(line)
            yield data['body']
            if i >= n:
                break


def get_matches(tokenizer, phrases, texts, max_length=6):
    matcher = PhraseMatcher(tokenizer.vocab, phrases, max_length=max_length)
    print("Match")
    matcher = PhraseMatcher(tokenizer.vocab, max_length=max_length)
    matcher.add('Phrase', None, *phrases)
    for text in texts:
        doc = tokenizer(text)
        for w in doc:
            _ = doc.vocab[w.text]
        matches = matcher(doc)
        for mwe in doc.ents:
            yield mwe
        for ent_id, start, end in matches:
            yield (ent_id, doc[start:end].text)


def main(patterns_loc, text_loc, counts_loc, n=10000000):
    nlp = English(parser=False, tagger=False, entity=False)
    print("Make matcher")
    phrases = read_gazetteer(nlp.tokenizer, patterns_loc, n=n)
    counts = PreshCounter()
def main(patterns_loc, text_loc, n=10000):
    nlp = spacy.blank('en')
    nlp.vocab.lex_attr_getters = {}
    phrases = read_gazetteer(nlp.tokenizer, patterns_loc)
    count = 0
    t1 = time.time()
    for mwe in get_matches(nlp.tokenizer, phrases, read_text(text_loc)):
        counts.inc(hash_string(mwe.text), 1)
    for ent_id, text in get_matches(nlp.tokenizer, phrases, read_text(text_loc, n=n)):
        count += 1
    t2 = time.time()
    print("10m tokens in %d s" % (t2 - t1))

    with codecs.open(counts_loc, 'w', 'utf8') as file_:
        for phrase in read_gazetteer(nlp.tokenizer, patterns_loc, n=n):
            text = phrase.string
            key = hash_string(text)
            count = counts[key]
            if count != 0:
                file_.write('%d\t%s\n' % (count, text))
    print("%d docs in %.3f s. %d matches" % (n, (t2 - t1), count))


if __name__ == '__main__':
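The rewritten example drives the matcher from a jsonl gazetteer, one `{"text": ...}` entry per line, and keeps only multi-word phrases. A simplified, standalone sketch of that loading step; it uses the standard-library `json` module instead of `ujson`, and the `patterns.jsonl` path and `min_length` parameter are placeholders added for illustration:

```python
import json
import spacy

def read_gazetteer(tokenizer, loc, min_length=2):
    """Yield tokenized phrases from a jsonl file of {"text": ...} entries."""
    with open(loc) as file_:
        for line in file_:
            data = json.loads(line.strip())
            phrase = tokenizer(data['text'])
            if len(phrase) >= min_length:  # the example only keeps multi-word phrases
                yield phrase

nlp = spacy.blank('en')
phrases = list(read_gazetteer(nlp.tokenizer, 'patterns.jsonl'))  # placeholder path
```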
@@ -421,47 +421,69 @@ cdef class PhraseMatcher:
    cdef int max_length
    cdef attr_t* _phrase_key

    def __init__(self, Vocab vocab, phrases, max_length=10):
    cdef public object _callbacks
    cdef public object _patterns

    def __init__(self, Vocab vocab, max_length=10):
        self.mem = Pool()
        self._phrase_key = <attr_t*>self.mem.alloc(max_length, sizeof(attr_t))
        self.max_length = max_length
        self.vocab = vocab
        self.matcher = Matcher(self.vocab, {})
        self.matcher = Matcher(self.vocab)
        self.phrase_ids = PreshMap()
        for phrase in phrases:
            if len(phrase) < max_length:
                self.add(phrase)

        abstract_patterns = []
        for length in range(1, max_length):
            abstract_patterns.append([{tag: True} for tag in get_bilou(length)])
        self.matcher.add('Candidate', 'MWE', {}, abstract_patterns, acceptor=self.accept_match)
        self.matcher.add('Candidate', None, *abstract_patterns)
        self._callbacks = {}

    def add(self, Doc tokens):
        cdef int length = tokens.length
        assert length < self.max_length
        tags = get_bilou(length)
        assert len(tags) == length, length
    def __len__(self):
        raise NotImplementedError

    def __contains__(self, key):
        raise NotImplementedError

    def __reduce__(self):
        return (self.__class__, (self.vocab,), None, None)

    def add(self, key, on_match, *docs):
        cdef Doc doc
        for doc in docs:
            if len(doc) >= self.max_length:
                msg = (
                    "Pattern length (%d) >= phrase_matcher.max_length (%d). "
                    "Length can be set on initialization, up to 10."
                )
                raise ValueError(msg % (len(doc), self.max_length))
        cdef hash_t ent_id = self.matcher._normalize_key(key)
        self._callbacks[ent_id] = on_match

        cdef int length
        cdef int i
        cdef hash_t phrase_hash
        for doc in docs:
            length = doc.length
            tags = get_bilou(length)
            for i in range(self.max_length):
                self._phrase_key[i] = 0
            for i, tag in enumerate(tags):
            lexeme = self.vocab[tokens.c[i].lex.orth]
                lexeme = self.vocab[doc.c[i].lex.orth]
                lexeme.set_flag(tag, True)
                self._phrase_key[i] = lexeme.orth
        cdef hash_t key = hash64(self._phrase_key, self.max_length * sizeof(attr_t), 0)
        self.phrase_ids[key] = True
            phrase_hash = hash64(self._phrase_key,
                                 self.max_length * sizeof(attr_t), 0)
            self.phrase_ids.set(phrase_hash, <void*>ent_id)

    def __call__(self, Doc doc):
        matches = []
        for ent_id, label, start, end in self.matcher(doc):
            cand = doc[start : end]
            start = cand[0].idx
            end = cand[-1].idx + len(cand[-1])
            matches.append((start, end, cand.root.tag_, cand.text, 'MWE'))
        for match in matches:
            doc.merge(*match)
        for _, start, end in self.matcher(doc):
            ent_id = self.accept_match(doc, start, end)
            if ent_id is not None:
                matches.append((ent_id, start, end))
        for i, (ent_id, start, end) in enumerate(matches):
            on_match = self._callbacks.get(ent_id)
            if on_match is not None:
                on_match(self, doc, i, matches)
        return matches

    def pipe(self, stream, batch_size=1000, n_threads=2):
@@ -469,7 +491,7 @@ cdef class PhraseMatcher:
            self(doc)
            yield doc

    def accept_match(self, Doc doc, attr_t ent_id, attr_t label, int start, int end):
    def accept_match(self, Doc doc, int start, int end):
        assert (end - start) < self.max_length
        cdef int i, j
        for i in range(self.max_length):
@@ -477,7 +499,8 @@ cdef class PhraseMatcher:
        for i, j in enumerate(range(start, end)):
            self._phrase_key[i] = doc.c[j].lex.orth
        cdef hash_t key = hash64(self._phrase_key, self.max_length * sizeof(attr_t), 0)
        if self.phrase_ids.get(key):
            return (ent_id, label, start, end)
        ent_id = <hash_t>self.phrase_ids.get(key)
        if ent_id == 0:
            return None
        else:
            return False
            return ent_id
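The new `add()` also accepts an `on_match` callback, stored per key in `_callbacks` and invoked as `on_match(matcher, doc, i, matches)` for every match, mirroring the token `Matcher`. A small sketch of that hook, reusing the "Google Now" phrase from the test below; the callback body is illustrative only:

```python
import spacy
from spacy.matcher import PhraseMatcher

nlp = spacy.blank('en')
matcher = PhraseMatcher(nlp.vocab)

def collect(matcher, doc, i, matches):
    # Receives the matcher, the doc, the index of this match, and the full match list.
    match_id, start, end = matches[i]
    print('matched:', doc[start:end].text)

matcher.add('COMPANY', collect, nlp('Google Now'))
matches = matcher(nlp('I like Google Now best'))  # collect() fires once for the single match
```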
@@ -34,7 +34,6 @@ def test_matcher_from_api_docs(en_vocab):
    assert len(patterns[0])


@pytest.mark.xfail
def test_matcher_from_usage_docs(en_vocab):
    text = "Wow 😀 This is really cool! 😂 😂"
    doc = get_doc(en_vocab, words=text.split(' '))
@@ -46,7 +45,8 @@ def test_matcher_from_usage_docs(en_vocab):
        if doc.vocab.strings[match_id] == 'HAPPY':
            doc.sentiment += 0.1
        span = doc[start : end]
        token = span.merge(norm='happy emoji')
        token = span.merge()
        token.vocab[token.text].norm_ = 'happy emoji'

    matcher = Matcher(en_vocab)
    matcher.add('HAPPY', label_sentiment, *pos_patterns)
@@ -98,11 +98,11 @@ def test_matcher_match_multi(matcher):
                            (doc.vocab.strings['Java'], 5, 6)]


@pytest.mark.xfail
def test_matcher_phrase_matcher(en_vocab):
    words = ["Google", "Now"]
    doc = get_doc(en_vocab, words)
    matcher = PhraseMatcher(en_vocab, [doc])
    matcher = PhraseMatcher(en_vocab)
    matcher.add('COMPANY', None, doc)
    words = ["I", "like", "Google", "Now", "best"]
    doc = get_doc(en_vocab, words)
    assert len(matcher(doc)) == 1