mirror of https://github.com/explosion/spaCy.git
synced 2025-11-04 01:48:04 +03:00
Fix multi word matcher

This commit is contained in:
parent 801d55a6d9
commit 4bbc8f45c6
@@ -26,137 +26,71 @@ from ast import literal_eval
 from bz2 import BZ2File
 import time
-import math
+import codecs
 
 import plac
 
-from preshed.maps import PreshMap
+from preshed.counter import PreshCounter
 from spacy.strings import hash_string
 from spacy.en import English
-from spacy.matcher import Matcher
-
-from spacy.attrs import FLAG63 as B_ENT
-from spacy.attrs import FLAG62 as L_ENT
-from spacy.attrs import FLAG61 as I_ENT
-
-from spacy.attrs import FLAG60 as B2_ENT
-from spacy.attrs import FLAG59 as B3_ENT
-from spacy.attrs import FLAG58 as B4_ENT
-from spacy.attrs import FLAG57 as B5_ENT
-from spacy.attrs import FLAG56 as B6_ENT
-from spacy.attrs import FLAG55 as B7_ENT
-from spacy.attrs import FLAG54 as B8_ENT
-from spacy.attrs import FLAG53 as B9_ENT
-from spacy.attrs import FLAG52 as B10_ENT
-
-from spacy.attrs import FLAG51 as I3_ENT
-from spacy.attrs import FLAG50 as I4_ENT
-from spacy.attrs import FLAG49 as I5_ENT
-from spacy.attrs import FLAG48 as I6_ENT
-from spacy.attrs import FLAG47 as I7_ENT
-from spacy.attrs import FLAG46 as I8_ENT
-from spacy.attrs import FLAG45 as I9_ENT
-from spacy.attrs import FLAG44 as I10_ENT
-
-from spacy.attrs import FLAG43 as L2_ENT
-from spacy.attrs import FLAG42 as L3_ENT
-from spacy.attrs import FLAG41 as L4_ENT
-from spacy.attrs import FLAG40 as L5_ENT
-from spacy.attrs import FLAG39 as L6_ENT
-from spacy.attrs import FLAG38 as L7_ENT
-from spacy.attrs import FLAG37 as L8_ENT
-from spacy.attrs import FLAG36 as L9_ENT
-from spacy.attrs import FLAG35 as L10_ENT
-
-
-def get_bilou(length):
-    if length == 1:
-        return [U_ENT]
-    elif length == 2:
-        return [B2_ENT, L2_ENT]
-    elif length == 3:
-        return [B3_ENT, I3_ENT, L3_ENT]
-    elif length == 4:
-        return [B4_ENT, I4_ENT, I4_ENT, L4_ENT]
-    elif length == 5:
-        return [B5_ENT, I5_ENT, I5_ENT, L5_ENT]
-    elif length == 6:
-        return [B6_ENT, I6_ENT, I6_ENT, I6_ENT, I6_ENT, L6_ENT]
-    elif length == 7:
-        return [B7_ENT, I7_ENT, I7_ENT, I7_ENT, I7_ENT, I7_ENT, L7_ENT]
-    elif length == 8:
-        return [B8_ENT, I8_ENT, I8_ENT, I8_ENT, I8_ENT, I8_ENT, I8_ENT, L8_ENT]
-    elif length == 9:
-        return [B9_ENT, I9_ENT, I9_ENT, I9_ENT, I9_ENT, I9_ENT, I9_ENT, L9_ENT]
-    elif length == 10:
-        return [B10_ENT, I10_ENT, I10_ENT, I10_ENT, I10_ENT, I10_ENT, I10_ENT, L10_ENT]
-
-
-def make_matcher(vocab, max_length):
-    abstract_patterns = []
-    for length in range(2, max_length):
-        abstract_patterns.append([{tag: True} for tag in get_bilou(length)])
-    return Matcher(vocab, {'Candidate': ('CAND', {}, abstract_patterns)})
-
-
-def get_matches(matcher, pattern_ids, doc):
-    matches = []
-    for label, start, end in matcher(doc):
-        candidate = doc[start : end]
-        if pattern_ids[hash_string(candidate.text)] == True:
-            start = candidate[0].idx
-            end = candidate[-1].idx + len(candidate[-1])
-            matches.append((start, end, candidate.root.tag_, candidate.text))
-    return matches
-
-
-def merge_matches(doc, matches):
-    for start, end, tag, text in matches:
-        doc.merge(start, end, tag, text, 'MWE')
-
-
-def read_gazetteer(loc):
-    for line in open(loc):
+from spacy.matcher import PhraseMatcher
+
+
+def read_gazetteer(tokenizer, loc, n=-1):
+    for i, line in enumerate(open(loc)):
         phrase = literal_eval('u' + line.strip())
         if ' (' in phrase and phrase.endswith(')'):
             phrase = phrase.split(' (', 1)[0]
-        yield phrase
+        if i >= n:
+            break
+        phrase = tokenizer(phrase)
+        if len(phrase) >= 2:
+            yield phrase
 
 
 def read_text(bz2_loc):
     with BZ2File(bz2_loc) as file_:
         for line in file_:
             yield line.decode('utf8')
 
-def main(patterns_loc, text_loc):
+
+def get_matches(tokenizer, phrases, texts, max_length=6):
+    matcher = PhraseMatcher(tokenizer.vocab, phrases, max_length=max_length)
+    print("Match")
+    for text in texts:
+        doc = tokenizer(text)
+        matches = matcher(doc)
+        for mwe in doc.ents:
+            yield mwe
+
+
+def main(patterns_loc, text_loc, counts_loc, n=10000000):
     nlp = English(parser=False, tagger=False, entity=False)
-
-    pattern_ids = PreshMap()
-    max_length = 10
-    i = 0
-    for pattern_str in read_gazetteer(patterns_loc):
-        pattern = nlp.tokenizer(pattern_str)
-        if len(pattern) < 2 or len(pattern) >= max_length:
-            continue
-        bilou_tags = get_bilou(len(pattern))
-        for word, tag in zip(pattern, bilou_tags):
-            lexeme = nlp.vocab[word.orth]
-            lexeme.set_flag(tag, True)
-        pattern_ids[hash_string(pattern.text)] = True
-        i += 1
-        if i >= 10000001:
-            break
-
-    matcher = make_matcher(nlp.vocab, max_length)
-
+    print("Make matcher")
+    phrases = read_gazetteer(nlp.tokenizer, patterns_loc, n=n)
+    counts = PreshCounter()
     t1 = time.time()
-
-    for text in read_text(text_loc):
-        doc = nlp.tokenizer(text)
-        matches = get_matches(matcher, pattern_ids, doc)
-        merge_matches(doc, matches)
+    for mwe in get_matches(nlp.tokenizer, phrases, read_text(text_loc)):
+        counts.inc(hash_string(mwe.text), 1)
     t2 = time.time()
-    print('10 ^ %d patterns took %d s' % (round(math.log(i, 10)), t2-t1))
-
+    print("10m tokens in %d s" % (t2 - t1))
+
+    with codecs.open(counts_loc, 'w', 'utf8') as file_:
+        for phrase in read_gazetteer(nlp.tokenizer, patterns_loc, n=n):
+            text = phrase.string
+            key = hash_string(text)
+            count = counts[key]
+            if count != 0:
+                file_.write('%d\t%s\n' % (count, text))
 
 
 if __name__ == '__main__':
-    plac.call(main)
+    if False:
+        import cProfile
+        import pstats
+        cProfile.runctx("plac.call(main)", globals(), locals(), "Profile.prof")
+        s = pstats.Stats("Profile.prof")
+        s.strip_dirs().sort_stats("time").print_stats()
+    else:
+        plac.call(main)
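The commit replaces the hand-rolled scheme (one boolean lexeme flag per entity position and length, plus BILOU-shaped Matcher patterns) with PhraseMatcher, which handles multi-word phrase lookup directly. As a rough sketch of the new control flow, using only the functions defined in the patched script above; the gazetteer and corpus paths are placeholders, and the PhraseMatcher(vocab, phrases, max_length=...) signature is the pre-1.0 one this commit targets:

    # Sketch only: wires together the functions from the patched script.
    # 'gazetteer.txt' and 'text.bz2' are placeholder paths.
    from spacy.en import English

    nlp = English(parser=False, tagger=False, entity=False)
    phrases = read_gazetteer(nlp.tokenizer, 'gazetteer.txt', n=1000)
    for mwe in get_matches(nlp.tokenizer, phrases, read_text('text.bz2')):
        print(mwe.text)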
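For readers on current spaCy (v2+), the same technique survives as spacy.matcher.PhraseMatcher, though the constructor and add() method differ from the version patched here. A minimal sketch, assuming a blank English pipeline:

    # Modern PhraseMatcher sketch (spaCy v2/v3 API, not the version above);
    # patterns are pre-tokenized Doc objects added under a match key.
    import spacy
    from spacy.matcher import PhraseMatcher

    nlp = spacy.blank("en")
    matcher = PhraseMatcher(nlp.vocab)
    matcher.add("CAND", [nlp("machine learning"), nlp("New York City")])

    doc = nlp("I study machine learning in New York City.")
    for match_id, start, end in matcher(doc):
        print(doc[start:end].text)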