mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-23 15:54:13 +03:00
131 lines
3.5 KiB
Python
131 lines
3.5 KiB
Python
#!/usr/bin/env python
|
|
from __future__ import division
|
|
from __future__ import unicode_literals
|
|
|
|
import os
|
|
from os import path
|
|
import shutil
|
|
import codecs
|
|
import random
|
|
import time
|
|
import gzip
|
|
|
|
import plac
|
|
import cProfile
|
|
import pstats
|
|
|
|
import spacy.util
|
|
from spacy.en import English
|
|
from spacy.en.pos import POS_TEMPLATES, POS_TAGS, setup_model_dir
|
|
|
|
from spacy.syntax.parser import GreedyParser
|
|
from spacy.syntax.parser import OracleError
|
|
from spacy.syntax.util import Config
|
|
|
|
|
|
def is_punct_label(label):
    """Return True when a dependency label denotes punctuation.

    Two spellings are recognised: the exact label 'P', and 'punct' in any
    letter case.
    """
    if label == 'P':
        return True
    return label.lower() == 'punct'
|
|
|
|
|
|
def read_gold(file_):
    """Read a standard CoNLL/MALT-style format.

    The whole of ``file_`` is read, stripped, and split into sentences on
    blank lines ('\\n\\n'); every remaining line is parsed by
    ``_parse_line``.

    Blank or whitespace-only lines, and sentence chunks that contain nothing
    else, are skipped. An empty file therefore yields an empty list instead
    of failing inside ``_parse_line``.

    Returns a list with one tuple per sentence:
    ``(text, [words], ids, words, tags, heads, labels)`` where ``text`` is
    the words joined by single spaces and the second item is a one-element
    list wrapping the same ``words`` list.
    """
    sents = []
    for sent_str in file_.read().strip().split('\n\n'):
        # Filter before enumerating so that the positions used for the -1
        # head replacement below count only real token lines.
        lines = [line for line in sent_str.split('\n') if line.strip()]
        if not lines:
            continue
        ids = []
        words = []
        heads = []
        labels = []
        tags = []
        for i, line in enumerate(lines):
            id_, word, pos_string, head_idx, label = _parse_line(line)
            # A head of -1 is replaced by the token's own 0-based position
            # in the sentence (presumably the root marker -- confirm
            # against the data format in use).
            if head_idx == -1:
                head_idx = i
            ids.append(id_)
            words.append(word)
            tags.append(pos_string)
            heads.append(head_idx)
            labels.append(label)
        text = ' '.join(words)
        sents.append((text, [words], ids, words, tags, heads, labels))
    return sents
|
|
|
|
|
|
def _parse_line(line):
|
|
pieces = line.split()
|
|
id_ = int(pieces[0])
|
|
word = pieces[1]
|
|
pos = pieces[3]
|
|
head_idx = int(pieces[6])
|
|
label = pieces[7]
|
|
return id_, word, pos, head_idx, label
|
|
|
|
|
|
def iter_data(paragraphs, tokenizer, gold_preproc=False):
    """Yield per-sentence gold annotations for every paragraph.

    Each item of ``paragraphs`` is a 7-tuple
    ``(raw, tokenized, ids, words, tags, heads, labels)``. ``tokenized`` is
    a sequence of word lists; the flat ``ids``/``tags``/``heads``/``labels``
    sequences are consumed from the front, one word list's length at a time.

    For every word list this yields
    ``(tokens, sent_tags, sent_heads, sent_labels)`` where ``tokens`` comes
    from ``tokenizer.tokens_from_list`` and ``sent_heads`` has been passed
    through ``_map_indices_to_tokens`` (entries may be None).

    ``gold_preproc`` is accepted but not used by this function.
    """
    for _raw, sentences, ids, words, tags, heads, labels in paragraphs:
        assert len(words) == len(heads)
        # Walk the flat annotation lists with a running offset rather than
        # re-slicing them after every sentence.
        start = 0
        for sent_words in sentences:
            stop = start + len(sent_words)
            sent_ids = ids[start:stop]
            sent_tags = tags[start:stop]
            sent_labels = labels[start:stop]
            sent_heads = _map_indices_to_tokens(sent_ids, heads[start:stop])
            tokens = tokenizer.tokens_from_list(sent_words)
            yield tokens, sent_tags, sent_heads, sent_labels
            start = stop
|
|
|
|
|
|
def _map_indices_to_tokens(ids, heads):
|
|
mapped = []
|
|
for head in heads:
|
|
if head not in ids:
|
|
mapped.append(None)
|
|
else:
|
|
mapped.append(ids.index(head))
|
|
return mapped
|
|
|
|
|
|
|
|
def evaluate(Language, dev_loc, model_dir):
    """Score a pipeline against the gold-standard file at ``dev_loc``.

    ``Language`` is called with no arguments to build the pipeline. The gold
    file is opened as UTF-8 and read with ``read_gold``; each sentence from
    ``iter_data`` is tagged via ``nlp.tagger.tag_from_strings`` and parsed
    via ``nlp.parser``.

    Two lines are printed: ``loss skipped (loss + skipped + total)`` and the
    fraction of tokens whose ``tag_`` equals the gold tag. Tokens whose gold
    head is None are counted as skipped, and tokens with a punctuation label
    are ignored; neither group contributes to the attachment score.

    Returns ``n_corr / (total + loss)`` as a float, i.e. the share of scored
    tokens whose predicted head index equals the gold head. Both ratios fall
    back to 0.0 when their denominator is zero (for example an empty gold
    file) instead of raising ZeroDivisionError.

    NOTE(review): ``model_dir`` is accepted but never used here -- confirm
    whether it was meant to be passed to ``Language``.
    """
    # `loss` is published as a module-level name; it is reset here and never
    # incremented in this function, so it contributes 0 to every total below.
    global loss
    nlp = Language()
    n_corr = 0
    pos_corr = 0
    n_tokens = 0
    total = 0
    skipped = 0
    loss = 0
    with codecs.open(dev_loc, 'r', 'utf8') as file_:
        paragraphs = read_gold(file_)
    for tokens, tag_strs, heads, labels in iter_data(paragraphs, nlp.tokenizer):
        assert len(tokens) == len(labels)
        # Presumably assigns the gold tags to the tokens before parsing --
        # confirm against the tagger API.
        nlp.tagger.tag_from_strings(tokens, tag_strs)
        nlp.parser(tokens)
        for i, token in enumerate(tokens):
            try:
                pos_corr += token.tag_ == tag_strs[i]
            except Exception:
                # Show which token failed, then let the error propagate.
                print('%s %s %s' % (i, token.orth_, token.tag))
                raise
            n_tokens += 1
            if heads[i] is None:
                skipped += 1
                continue
            if is_punct_label(labels[i]):
                continue
            n_corr += token.head.i == heads[i]
            total += 1
    # Single-argument print calls behave the same as a print statement on
    # Python 2 and as the print function on Python 3.
    print('%s %s %s' % (loss, skipped, loss + skipped + total))
    # float() keeps this a true division even without the module's
    # `from __future__ import division`.
    print(float(pos_corr) / n_tokens if n_tokens else 0.0)
    scored = total + loss
    return float(n_corr) / scored if scored else 0.0
|
|
|
|
|
|
def main(dev_loc, model_dir):
    # Command-line entry point: evaluate the English pipeline on the gold
    # file at dev_loc and print the score returned by evaluate().
    # Documented with a comment rather than a docstring because plac uses
    # the function's docstring as the command-line help text.
    score = evaluate(English, dev_loc, model_dir)
    print(score)
|
|
|
|
|
|
# Run main() through plac only when this file is executed as a script; plac
# derives the command-line arguments from main's signature.
if __name__ == '__main__':
    plac.call(main)
|