* Allow gold tokenization training, for debugging

Matthew Honnibal 2015-03-08 00:17:12 -05:00
parent 8da53cbe3c
commit 7a1a333f04


@@ -21,7 +21,7 @@ from spacy.en.pos import POS_TEMPLATES, POS_TAGS, setup_model_dir
 from spacy.syntax.parser import GreedyParser
 from spacy.syntax.parser import OracleError
 from spacy.syntax.util import Config
-from spacy.syntax.conll import GoldParse
+from spacy.syntax.conll import GoldParse, is_punct_label


 def is_punct_label(label):
@@ -206,15 +206,22 @@ def train(Language, paragraphs, model_dir, n_iter=15, feat_set=u'basic', seed=0,
         heads_corr = 0
         pos_corr = 0
         n_tokens = 0
+        n_all_tokens = 0
         for gold_sent in gold_sents:
-            tokens = nlp.tokenizer(gold_sent.raw_text)
-            gold_sent.align_to_tokens(tokens, nlp.parser.moves.label_ids)
+            if gold_preproc:
+                #print ' '.join(gold_sent.words)
+                tokens = nlp.tokenizer.tokens_from_list(gold_sent.words)
+                gold_sent.map_heads(nlp.parser.moves.label_ids)
+            else:
+                tokens = nlp.tokenizer(gold_sent.raw_text)
+                gold_sent.align_to_tokens(tokens, nlp.parser.moves.label_ids)
             nlp.tagger(tokens)
             heads_corr += nlp.parser.train(tokens, gold_sent, force_gold=force_gold)
             pos_corr += nlp.tagger.train(tokens, gold_sent.tags)
-            n_tokens += len(tokens)
+            n_tokens += gold_sent.n_non_punct
+            n_all_tokens += len(tokens)
         acc = float(heads_corr) / n_tokens
-        pos_acc = float(pos_corr) / n_tokens
+        pos_acc = float(pos_corr) / n_all_tokens
         print '%d: ' % itn, '%.3f' % acc, '%.3f' % pos_acc
         random.shuffle(gold_sents)
     nlp.parser.model.end_training()
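In effect, the hunk above gives train() two input modes. A minimal sketch of the branch, assuming the 2015-era spaCy API visible in the diff (nlp, tokens_from_list, map_heads, and align_to_tokens are taken from the hunk; the standalone helper itself is hypothetical):

# Hypothetical helper isolating the two tokenization modes from the diff.
# With gold_preproc=True the gold-standard word list is trusted outright,
# so the gold heads only need their labels mapped; otherwise the raw text
# is tokenized and the gold annotations are aligned to whatever came out.
def make_tokens(nlp, gold_sent, gold_preproc=False):
    if gold_preproc:
        tokens = nlp.tokenizer.tokens_from_list(gold_sent.words)
        gold_sent.map_heads(nlp.parser.moves.label_ids)
    else:
        tokens = nlp.tokenizer(gold_sent.raw_text)
        gold_sent.align_to_tokens(tokens, nlp.parser.moves.label_ids)
    return tokens

Training on gold tokens removes tokenizer/gold misalignment as a variable, which is what makes the new path useful for debugging, per the commit message.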
@@ -241,21 +248,26 @@ def evaluate(Language, dev_loc, model_dir, gold_preproc=False):
         nlp.tagger(tokens)
         nlp.parser(tokens)
         for i, token in enumerate(tokens):
-            pos_corr += token.tag_ == tag_strs[i]
+            pos_corr += token.tag_ == gold_sent.tags[i]
             n_tokens += 1
-            if heads[i] is None:
+            if gold_sent.heads[i] is None:
                 skipped += 1
                 continue
-            if is_punct_label(labels[i]):
-                continue
-            uas_corr += token.head.i == heads[i]
-            las_corr += token.head.i == heads[i] and token.dep_ == labels[i]
-            #print token.orth_, token.head.orth_, token.dep_, labels[i]
+            #print i, token.orth_, token.head.i, gold_sent.py_heads[i], gold_sent.labels[i],
+            #print gold_sent.is_correct(i, token.head.i)
+            if gold_sent.labels[i] != 'P':
+                n_corr += gold_sent.is_correct(i, token.head.i)
             total += 1
     print loss, skipped, (loss+skipped + total)
     print pos_corr / n_tokens
-    print float(las_corr) / (total + loss)
-    return float(uas_corr) / (total + loss)
+    return float(n_corr) / (total + loss)
+
+
+def read_gold(loc, n=0):
+    sent_strs = open(loc).read().strip().split('\n\n')
+    if n == 0:
+        n = len(sent_strs)
+    return [GoldParse.from_docparse(sent) for sent in sent_strs[:n]]


 @plac.annotations(
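The evaluate() change drops the manual UAS/LAS bookkeeping in favour of GoldParse.is_correct(), excluding punctuation via its 'P' label. A sketch of the score as now computed (is_correct, heads, and labels follow the diff; note that total still counts punctuation tokens, exactly as in the hunk):

# Sketch of the attachment score after this commit. Punctuation (label 'P')
# can no longer add to n_corr, though it still increments total, matching
# the diff; tokens with no gold head are skipped entirely.
def attachment_score(gold_sent, tokens, loss=0):
    n_corr = 0
    total = 0
    for i, token in enumerate(tokens):
        if gold_sent.heads[i] is None:
            continue
        if gold_sent.labels[i] != 'P':
            n_corr += gold_sent.is_correct(i, token.head.i)
        total += 1
    return float(n_corr) / (total + loss)

The new read_gold() helper reads a docparse file with blank-line-separated sentence blocks; passing n=0 loads them all.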
@@ -265,8 +277,8 @@ def evaluate(Language, dev_loc, model_dir, gold_preproc=False):
     n_sents=("Number of training sentences", "option", "n", int)
 )
 def main(train_loc, dev_loc, model_dir, n_sents=0):
-    train(English, read_gold(train_loc, n=n_sents), model_dir,
-          gold_preproc=False, force_gold=False)
+    #train(English, read_gold(train_loc, n=n_sents), model_dir,
+    #      gold_preproc=False, force_gold=False)
     print evaluate(English, read_gold(dev_loc), model_dir, gold_preproc=False)
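With the train() call commented out, main() only runs evaluation. To exercise the gold-tokenization path this commit enables, the call would presumably be restored with gold_preproc=True (a hypothetical invocation; the flag values are illustrative, not part of the commit):

def main(train_loc, dev_loc, model_dir, n_sents=0):
    # Hypothetical: re-enable training on gold tokens for debugging.
    train(English, read_gold(train_loc, n=n_sents), model_dir,
          gold_preproc=True, force_gold=False)
    print evaluate(English, read_gold(dev_loc), model_dir, gold_preproc=False)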