#!/usr/bin/env python
"""Evaluate spaCy's English tagger and parser against a CoNLL/MALT-style
gold-standard development file."""
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import codecs

import plac

from spacy.en import English


def is_punct_label(label):
    return label == 'P' or label.lower() == 'punct'


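# Both the PTB-style 'P' tag and Stanford-style 'punct' labels count as
# punctuation; the check is case-insensitive, so e.g. is_punct_label('PUNCT')
# is also True.

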
def read_gold(file_):
    """Read sentences from a CoNLL/MALT-style dependency file."""
    sents = []
    for sent_str in file_.read().strip().split('\n\n'):
        ids = []
        words = []
        heads = []
        labels = []
        tags = []
        for i, line in enumerate(sent_str.split('\n')):
            id_, word, pos_string, head_idx, label = _parse_line(line)
            words.append(word)
            if head_idx == -1:
                # The root token is attached to itself.
                head_idx = i
            ids.append(id_)
            heads.append(head_idx)
            labels.append(label)
            tags.append(pos_string)
        text = ' '.join(words)
        sents.append((text, [words], ids, words, tags, heads, labels))
    return sents


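# read_gold returns one tuple per sentence:
#     (raw_text, [words], ids, words, tags, heads, labels)
# The single-element [words] list gives each sentence the paragraph shape
# that iter_data expects to iterate over.

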
def _parse_line(line):
    # Whitespace-separated CoNLL columns; only id, word, pos, head and label
    # (columns 0, 1, 3, 6 and 7) are read.
    pieces = line.split()
    id_ = int(pieces[0])
    word = pieces[1]
    pos = pieces[3]
    head_idx = int(pieces[6])
    label = pieces[7]
    return id_, word, pos, head_idx, label


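# A sketch of the expected input, assuming whitespace-separated columns where
# ids are token offsets within the sentence and a head of -1 marks the root
# (_parse_line only reads the id, word, pos, head and label columns shown):
#
#     0   Ms.      _   NNP   _   _   1    nn
#     1   Haag     _   NNP   _   _   2    nsubj
#     2   plays    _   VBZ   _   _   -1   ROOT
#     3   Elianti  _   NNP   _   _   2    dobj
#     4   .        _   .     _   _   2    punct
#
# Sentences are separated by blank lines.

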
def iter_data(paragraphs, tokenizer, gold_preproc=False):
    for raw, tokenized, ids, words, tags, heads, labels in paragraphs:
        assert len(words) == len(heads)
        for words in tokenized:
            sent_ids = ids[:len(words)]
            sent_tags = tags[:len(words)]
            sent_heads = heads[:len(words)]
            sent_labels = labels[:len(words)]
            sent_heads = _map_indices_to_tokens(sent_ids, sent_heads)
            tokens = tokenizer.tokens_from_list(words)
            yield tokens, sent_tags, sent_heads, sent_labels
            # Drop the sentence just consumed from the paragraph-level lists.
            ids = ids[len(words):]
            tags = tags[len(words):]
            heads = heads[len(words):]
            labels = labels[len(words):]


def _map_indices_to_tokens(ids, heads):
    mapped = []
    for head in heads:
        if head not in ids:
            # The head lies outside this sentence; mark it unattachable.
            mapped.append(None)
        else:
            mapped.append(ids.index(head))
    return mapped


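# For example, _map_indices_to_tokens([10, 11, 12], [11, 12, 99]) returns
# [1, 2, None]: head 99 does not occur among the sentence's ids, so that
# token is marked None and skipped when the attachment score is computed.

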
def evaluate(Language, dev_loc, model_dir):
    """Score the tagger and parser on a CoNLL-formatted development set."""
    nlp = Language()
    n_corr = 0
    pos_corr = 0
    n_tokens = 0
    total = 0
    skipped = 0
    loss = 0
    with codecs.open(dev_loc, 'r', 'utf8') as file_:
        paragraphs = read_gold(file_)
    for tokens, tag_strs, heads, labels in iter_data(paragraphs, nlp.tokenizer):
        assert len(tokens) == len(labels)
        nlp.tagger.tag_from_strings(tokens, tag_strs)
        nlp.parser(tokens)
        for i, token in enumerate(tokens):
            try:
                pos_corr += token.tag_ == tag_strs[i]
            except Exception:
                print(i, token.orth_, token.tag)
                raise
            n_tokens += 1
            if heads[i] is None:
                # Head fell outside the sentence; don't score this token.
                skipped += 1
                continue
            if is_punct_label(labels[i]):
                # Punctuation attachments are excluded from the score.
                continue
            n_corr += token.head.i == heads[i]
            total += 1
    print(loss, skipped, (loss + skipped + total))
    print(pos_corr / n_tokens)
    # Unlabelled attachment score over the scored (non-punct, attachable) tokens.
    return float(n_corr) / (total + loss)


def main(dev_loc, model_dir):
    print(evaluate(English, dev_loc, model_dir))


if __name__ == '__main__':
    plac.call(main)
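# Example invocation (the script name and paths here are hypothetical):
#
#     python evaluate_parser.py dev.conll /path/to/model
#
# plac maps the positional command-line arguments onto main's dev_loc and
# model_dir parameters.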