* Allow gold tokenization training, for debugging

commit 7a1a333f04
parent 8da53cbe3c
Author: Matthew Honnibal
Date:   2015-03-08 00:17:12 -05:00

@@ -21,7 +21,7 @@ from spacy.en.pos import POS_TEMPLATES, POS_TAGS, setup_model_dir
 from spacy.syntax.parser import GreedyParser
 from spacy.syntax.parser import OracleError
 from spacy.syntax.util import Config
-from spacy.syntax.conll import GoldParse
+from spacy.syntax.conll import GoldParse, is_punct_label
 
 
 def is_punct_label(label):
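
The only change in this hunk is the import: is_punct_label now comes from spacy.syntax.conll, so training and evaluation can share one definition instead of the local copy below. A helper like this is typically just a label check; a minimal sketch, assuming punctuation arcs carry the label 'P' (the label the evaluate() hunk below tests for) or a lowercase 'punct' variant, neither confirmed by this diff:

def is_punct_label(label):
    # Hypothetical sketch of the helper imported from spacy.syntax.conll;
    # the real implementation may differ. 'P' matches the check used in
    # the evaluate() hunk below; 'punct' is an assumed alternative.
    return label == 'P' or label.lower() == 'punct'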
@@ -206,15 +206,22 @@ def train(Language, paragraphs, model_dir, n_iter=15, feat_set=u'basic', seed=0,
         heads_corr = 0
         pos_corr = 0
         n_tokens = 0
+        n_all_tokens = 0
         for gold_sent in gold_sents:
-            tokens = nlp.tokenizer(gold_sent.raw_text)
-            gold_sent.align_to_tokens(tokens, nlp.parser.moves.label_ids)
+            if gold_preproc:
+                #print ' '.join(gold_sent.words)
+                tokens = nlp.tokenizer.tokens_from_list(gold_sent.words)
+                gold_sent.map_heads(nlp.parser.moves.label_ids)
+            else:
+                tokens = nlp.tokenizer(gold_sent.raw_text)
+                gold_sent.align_to_tokens(tokens, nlp.parser.moves.label_ids)
             nlp.tagger(tokens)
             heads_corr += nlp.parser.train(tokens, gold_sent, force_gold=force_gold)
             pos_corr += nlp.tagger.train(tokens, gold_sent.tags)
-            n_tokens += len(tokens)
+            n_tokens += gold_sent.n_non_punct
+            n_all_tokens += len(tokens)
         acc = float(heads_corr) / n_tokens
-        pos_acc = float(pos_corr) / n_tokens
+        pos_acc = float(pos_corr) / n_all_tokens
         print '%d: ' % itn, '%.3f' % acc, '%.3f' % pos_acc
         random.shuffle(gold_sents)
     nlp.parser.model.end_training()
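
The gold_preproc flag is the point of the commit: when it is set, the tokenizer's own segmentation is bypassed and the token sequence is built directly from the gold-standard word list, so parser errors can be debugged separately from tokenization errors. The bookkeeping also changes: head accuracy is now scored over non-punctuation tokens only (gold_sent.n_non_punct), while tagging accuracy keeps the full token count. A standalone sketch of the branch, plus toy numbers for the two denominators; the nlp and gold_sent names mirror the diff but are stand-ins, not a stable API:

def get_training_tokens(nlp, gold_sent, gold_preproc=False):
    # Sketch of the branch introduced above; mirrors the diff, not an API.
    if gold_preproc:
        # Debugging path: trust the gold segmentation and build tokens
        # straight from the annotated word list.
        tokens = nlp.tokenizer.tokens_from_list(gold_sent.words)
        gold_sent.map_heads(nlp.parser.moves.label_ids)
    else:
        # Normal path: tokenize raw text, then align the possibly
        # different segmentation to the gold annotations.
        tokens = nlp.tokenizer(gold_sent.raw_text)
        gold_sent.align_to_tokens(tokens, nlp.parser.moves.label_ids)
    return tokens

# Toy numbers (made up) illustrating the two denominators:
heads_corr, pos_corr = 90, 97
n_tokens, n_all_tokens = 95, 100            # non-punct vs. all tokens
acc = float(heads_corr) / n_tokens          # parse accuracy, punct excluded
pos_acc = float(pos_corr) / n_all_tokens    # POS accuracy, punct included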
@@ -241,21 +248,26 @@ def evaluate(Language, dev_loc, model_dir, gold_preproc=False):
         nlp.tagger(tokens)
         nlp.parser(tokens)
         for i, token in enumerate(tokens):
-            pos_corr += token.tag_ == tag_strs[i]
+            pos_corr += token.tag_ == gold_sent.tags[i]
             n_tokens += 1
-            if heads[i] is None:
+            if gold_sent.heads[i] is None:
                 skipped += 1
                 continue
-            if is_punct_label(labels[i]):
-                continue
-            uas_corr += token.head.i == heads[i]
-            las_corr += token.head.i == heads[i] and token.dep_ == labels[i]
-            #print token.orth_, token.head.orth_, token.dep_, labels[i]
-            total += 1
+            #print i, token.orth_, token.head.i, gold_sent.py_heads[i], gold_sent.labels[i],
+            #print gold_sent.is_correct(i, token.head.i)
+            if gold_sent.labels[i] != 'P':
+                n_corr += gold_sent.is_correct(i, token.head.i)
+                total += 1
     print loss, skipped, (loss+skipped + total)
     print pos_corr / n_tokens
-    print float(las_corr) / (total + loss)
-    return float(uas_corr) / (total + loss)
+    return float(n_corr) / (total + loss)
 
 
 def read_gold(loc, n=0):
     sent_strs = open(loc).read().strip().split('\n\n')
     if n == 0:
         n = len(sent_strs)
     return [GoldParse.from_docparse(sent) for sent in sent_strs[:n]]
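
The evaluation loop now asks the gold standard itself whether an attachment is correct (gold_sent.is_correct) instead of comparing raw head indices, skips tokens that could not be aligned to the gold parse (heads[i] is None), and excludes punctuation ('P' labels), so the score follows the usual unlabelled attachment convention; the labelled figure (las_corr) is dropped. A self-contained toy version of that scoring rule, with ToyGold as a hypothetical stand-in for GoldParse and is_correct assumed to be a plain head comparison (the diff's denominator also adds loss, which this sketch omits):

class ToyGold(object):
    # Hypothetical stand-in for GoldParse, for illustration only.
    def __init__(self, heads, labels):
        self.heads = heads      # gold head index per token; None = unaligned
        self.labels = labels    # gold dependency label per token
    def is_correct(self, i, guess_head):
        # Assumed semantics: the guessed head matches the gold head.
        return self.heads[i] == guess_head

gold = ToyGold(heads=[2, None, 2, 2], labels=['det', 'X', 'ROOT', 'P'])
pred_heads = [2, 0, 2, 1]       # token 3 is wrong, but it is punctuation
n_corr = total = skipped = 0
for i, guess in enumerate(pred_heads):
    if gold.heads[i] is None:   # unaligned token: counted as skipped
        skipped += 1
        continue
    if gold.labels[i] != 'P':   # punctuation excluded from the score
        n_corr += gold.is_correct(i, guess)
        total += 1
uas = float(n_corr) / total     # 2/2 = 1.0 despite the punctuation error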
@@ -265,8 +277,8 @@ def evaluate(Language, dev_loc, model_dir, gold_preproc=False):
 @plac.annotations(
     n_sents=("Number of training sentences", "option", "n", int)
 )
 def main(train_loc, dev_loc, model_dir, n_sents=0):
-    train(English, read_gold(train_loc, n=n_sents), model_dir,
-          gold_preproc=False, force_gold=False)
+    #train(English, read_gold(train_loc, n=n_sents), model_dir,
+    #      gold_preproc=False, force_gold=False)
     print evaluate(English, read_gold(dev_loc), model_dir, gold_preproc=False)
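
With the train() call commented out, the entry point only evaluates an existing model on the development set. The diff does not show how main is dispatched; assuming the usual plac entry point at the bottom of the script (an assumption, not shown here), it would be:

# Hypothetical dispatch; this diff does not show the script's footer.
if __name__ == '__main__':
    plac.call(main)

plac then maps the function signature and annotations onto a command line, so the script would be invoked roughly as python train.py <train_loc> <dev_loc> <model_dir> -n 1000, with -n bound to n_sents by the annotation above.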