Mirror of https://github.com/explosion/spaCy.git
Synced 2025-11-04 01:48:04 +03:00
Add standalone tagger training example

parent ad590feaa8
commit 8bb443be4f
examples/training/train_tagger_ud.py | 150 (new file)
@@ -0,0 +1,150 @@
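"""Train a standalone part-of-speech tagger on Universal Dependencies
(CoNLL-U) data, without loading a full spaCy pipeline.

Usage (paths are placeholders):
    python train_tagger_ud.py <train_loc> <dev_loc> [output_dir]

train_loc and dev_loc point to CoNLL-U files. If output_dir is given, the
trained model and string table are written there.
"""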
from __future__ import unicode_literals
from __future__ import print_function

import codecs
import random
from pathlib import Path

import plac

import spacy
import spacy.symbols as symbols
from spacy import attrs
from spacy import orth
from spacy.gold import GoldParse
from spacy.tagger import Tagger
from spacy.tokens import Doc
from spacy.vocab import Vocab
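
# Map Universal Dependencies coarse-grained tags to spaCy's POS symbols.
# Passing this tag map to the Vocab is what lets tagged tokens expose .pos_.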
TAG_MAP = {
    'ADJ': {symbols.POS: symbols.ADJ},
    'ADP': {symbols.POS: symbols.ADP},
    'PUNCT': {symbols.POS: symbols.PUNCT},
    'ADV': {symbols.POS: symbols.ADV},
    'AUX': {symbols.POS: symbols.AUX},
    'SYM': {symbols.POS: symbols.SYM},
    'INTJ': {symbols.POS: symbols.INTJ},
    'CCONJ': {symbols.POS: symbols.CCONJ},
    'X': {symbols.POS: symbols.X},
    'NOUN': {symbols.POS: symbols.NOUN},
    'DET': {symbols.POS: symbols.DET},
    'PROPN': {symbols.POS: symbols.PROPN},
    'NUM': {symbols.POS: symbols.NUM},
    'VERB': {symbols.POS: symbols.VERB},
    'PART': {symbols.POS: symbols.PART},
    'PRON': {symbols.POS: symbols.PRON},
    'SCONJ': {symbols.POS: symbols.SCONJ},
}
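
# Lexeme attribute getters: each function maps the raw token string to one of
# the lexical features (shape, prefix, suffix, etc.) available to the tagger's
# feature templates. CLUSTER, IS_STOP and IS_OOV are stubbed out here.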
LEX_ATTR_GETTERS = {
    attrs.LOWER: lambda string: string.lower(),
    attrs.NORM: lambda string: string,
    attrs.SHAPE: orth.word_shape,
    attrs.PREFIX: lambda string: string[0],
    attrs.SUFFIX: lambda string: string[-3:],
    attrs.CLUSTER: lambda string: 0,
    attrs.IS_ALPHA: orth.is_alpha,
    attrs.IS_ASCII: orth.is_ascii,
    attrs.IS_DIGIT: lambda string: string.isdigit(),
    attrs.IS_LOWER: orth.is_lower,
    attrs.IS_PUNCT: orth.is_punct,
    attrs.IS_SPACE: lambda string: string.isspace(),
    attrs.IS_TITLE: orth.is_title,
    attrs.IS_UPPER: orth.is_upper,
    attrs.IS_BRACKET: orth.is_bracket,
    attrs.IS_QUOTE: orth.is_quote,
    attrs.IS_LEFT_PUNCT: orth.is_left_punct,
    attrs.IS_RIGHT_PUNCT: orth.is_right_punct,
    attrs.LIKE_URL: orth.like_url,
    attrs.LIKE_NUM: orth.like_number,
    attrs.LIKE_EMAIL: orth.like_email,
    attrs.IS_STOP: lambda string: False,
    attrs.IS_OOV: lambda string: True,
}
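

# A CoNLL-U token line is tab-separated: ID, FORM, LEMMA, UPOS, XPOS, FEATS,
# HEAD, DEPREL, DEPS, MISC. Token IDs restart at 1 for every sentence, so a
# drop in the ID signals a sentence boundary. Multiword-token ranges such as
# "3-4" and non-token lines (comments start with "#") are skipped.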
def read_ud_data(path):
    data = []
    last_number = -1
    sentence_words = []
    sentence_tags = []
    with codecs.open(path, encoding="utf-8") as f:
        while True:
            line = f.readline()
            if not line:
                break

            if line[0].isdigit():
                d = line.split()
                if "-" not in d[0]:
                    number = int(d[0])
                    if number < last_number:
                        data.append((sentence_words, sentence_tags))
                        sentence_words = []
                        sentence_tags = []
                    sentence_words.append(d[2])
                    sentence_tags.append(d[3])
                    last_number = number
    if len(sentence_words) > 0:
        data.append((sentence_words, sentence_tags))
    return data


def ensure_dir(path):
    if not path.exists():
        path.mkdir()
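

# The optional output directory mirrors the spaCy v1 data layout: tagger
# weights go to pos/model and the vocab's string table to vocab/strings.json.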
def main(train_loc, dev_loc, output_dir=None):
    if output_dir is not None:
        output_dir = Path(output_dir)
        ensure_dir(output_dir)
        ensure_dir(output_dir / "pos")
        ensure_dir(output_dir / "vocab")

    train_data = read_ud_data(train_loc)
    vocab = Vocab(tag_map=TAG_MAP, lex_attr_getters=LEX_ATTR_GETTERS)
    # Populate the vocab: looking a word up creates its lexeme, computing
    # the attribute values defined in LEX_ATTR_GETTERS.
    for words, _ in train_data:
        for word in words:
            _ = vocab[word]
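
    # Build the tagger model directly from spaCy v1's hand-written feature
    # templates (a linear, perceptron-style model), rather than loading a
    # pre-trained package.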
    model = spacy.tagger.TaggerModel(spacy.tagger.Tagger.feature_templates)
    tagger = Tagger(vocab, model)
    print(tagger.tag_names)
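
    # Online training: for each sentence, tagger.update() applies the current
    # model to the Doc and updates the weights against the gold tags. The data
    # is reshuffled after every epoch.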
    for itn in range(30):
        print("training model (iteration " + str(itn) + ")...")
        score = 0.
        num_samples = 0.
        for words, tags in train_data:
            doc = Doc(vocab, words=words)
            gold = GoldParse(doc, tags=tags)
            cost = tagger.update(doc, gold)
            for i, word in enumerate(doc):
                num_samples += 1
                if word.tag_ == tags[i]:
                    score += 1
        print('Train acc', score / num_samples)
        random.shuffle(train_data)
    tagger.model.end_training()
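
    # end_training() above finalizes the model's weights; what follows is a
    # per-token tag accuracy evaluation on the held-out dev set.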
    score = 0.0
    num_samples = 0
    test_data = read_ud_data(dev_loc)
    for words, tags in test_data:
        doc = Doc(vocab, words=words)
        tagger(doc)
        for i, word in enumerate(doc):
            num_samples += 1
            if word.tag_ == tags[i]:
                score += 1
    print("score: " + str(score / num_samples * 100.0))

    if output_dir is not None:
        tagger.model.dump(str(output_dir / 'pos' / 'model'))
        with (output_dir / 'vocab' / 'strings.json').open('w') as file_:
            tagger.vocab.strings.dump(file_)


if __name__ == '__main__':
    plac.call(main)