* Tmp commit of train, while I move to better alignment in gold standard

This commit is contained in:
Matthew Honnibal 2015-05-23 17:21:25 +02:00
parent bdaddc4103
commit f35503018e

View File

@ -11,6 +11,7 @@ import random
import plac
import cProfile
import pstats
import re
import spacy.util
from spacy.en import English
@ -51,11 +52,10 @@ def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic', seed=0
gold_tuples = gold_tuples[:n_sents]
nlp = Language(data_dir=model_dir)
print "Itn.\tUAS\tNER F.\tTag %"
print "Itn.\tUAS\tNER F.\tTag %\tToken %"
for itn in range(n_iter):
scorer = Scorer()
for raw_text, segmented_text, annot_tuples, ctnt in gold_tuples:
# Eval before train
tokens = nlp(raw_text, merge_mwes=False)
gold = GoldParse(tokens, annot_tuples)
scorer.score(tokens, gold, verbose=False)
@ -67,12 +67,18 @@ def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic', seed=0
for tokens in sents:
gold = GoldParse(tokens, annot_tuples)
nlp.tagger(tokens)
try:
nlp.parser.train(tokens, gold)
except AssertionError:
# TODO: Do something about non-projective sentences
continue
if gold.ents:
nlp.entity.train(tokens, gold)
nlp.tagger.train(tokens, gold.tags)
print '%d:\t%.3f\t%.3f\t%.3f' % (itn, scorer.uas, scorer.ents_f, scorer.tags_acc)
print '%d:\t%.3f\t%.3f\t%.3f\t%.3f' % (itn, scorer.uas, scorer.ents_f,
scorer.tags_acc,
scorer.token_acc)
random.shuffle(gold_tuples)
nlp.parser.model.end_training()
nlp.entity.model.end_training()
@ -106,6 +112,10 @@ def write_parses(Language, dev_loc, model_dir, out_loc):
def get_sents(json_dir, section):
if path.exists(path.join(json_dir, section + '.json')):
for sent in read_json_file(path.join(json_dir, section + '.json')):
yield sent
else:
if section == 'train':
file_range = range(2, 22)
elif section == 'dev':
@ -137,7 +147,7 @@ def main(json_dir, model_dir, n_sents=0, out_loc="", verbose=False,
write_parses(English, dev_loc, model_dir, out_loc)
scorer = evaluate(English, list(get_sents(json_dir, 'dev')),
model_dir, gold_preproc=False, verbose=verbose)
print 'TOK', scorer.mistokened
print 'TOK', 100-scorer.token_acc
print 'POS', scorer.tags_acc
print 'UAS', scorer.uas
print 'LAS', scorer.las