* Tmp commit

This commit is contained in:
Matthew Honnibal 2015-02-23 14:05:04 -05:00
parent 10ed738df2
commit f5f15a1ef2

View File

@ -21,6 +21,7 @@ from spacy.en.pos import POS_TEMPLATES, POS_TAGS, setup_model_dir
from spacy.syntax.parser import GreedyParser from spacy.syntax.parser import GreedyParser
from spacy.syntax.parser import OracleError from spacy.syntax.parser import OracleError
from spacy.syntax.util import Config from spacy.syntax.util import Config
from spacy.syntax.conll import GoldParse
def is_punct_label(label): def is_punct_label(label):
@ -184,6 +185,7 @@ def get_labels(sents):
def train(Language, paragraphs, model_dir, n_iter=15, feat_set=u'basic', seed=0, def train(Language, paragraphs, model_dir, n_iter=15, feat_set=u'basic', seed=0,
gold_preproc=False, force_gold=False): gold_preproc=False, force_gold=False):
print "Setup model dir"
dep_model_dir = path.join(model_dir, 'deps') dep_model_dir = path.join(model_dir, 'deps')
pos_model_dir = path.join(model_dir, 'pos') pos_model_dir = path.join(model_dir, 'pos')
if path.exists(dep_model_dir): if path.exists(dep_model_dir):
@ -198,7 +200,6 @@ def train(Language, paragraphs, model_dir, n_iter=15, feat_set=u'basic', seed=0,
labels = Language.ParserTransitionSystem.get_labels(gold_sents) labels = Language.ParserTransitionSystem.get_labels(gold_sents)
Config.write(dep_model_dir, 'config', features=feat_set, seed=seed, Config.write(dep_model_dir, 'config', features=feat_set, seed=seed,
labels=labels) labels=labels)
nlp = Language() nlp = Language()
for itn in range(n_iter): for itn in range(n_iter):
@ -206,16 +207,16 @@ def train(Language, paragraphs, model_dir, n_iter=15, feat_set=u'basic', seed=0,
pos_corr = 0 pos_corr = 0
n_tokens = 0 n_tokens = 0
for gold_sent in gold_sents: for gold_sent in gold_sents:
tokens = nlp.tokenizer(gold_sent.raw) tokens = nlp.tokenizer(gold_sent.raw_text)
gold_sent.align_to_tokens(tokens) gold_sent.align_to_tokens(tokens, nlp.parser.moves.label_ids)
nlp.tagger(tokens) nlp.tagger(tokens)
heads_corr += nlp.parser.train(tokens, gold_sent, force_gold=force_gold) heads_corr += nlp.parser.train(tokens, gold_sent, force_gold=force_gold)
pos_corr += nlp.tagger.train(tokens, gold_parse.tags) pos_corr += nlp.tagger.train(tokens, gold_sent.tags)
n_tokens += len(tokens) n_tokens += len(tokens)
acc = float(heads_corr) / n_tokens acc = float(heads_corr) / n_tokens
pos_acc = float(pos_corr) / n_tokens pos_acc = float(pos_corr) / n_tokens
print '%d: ' % itn, '%.3f' % acc, '%.3f' % pos_acc print '%d: ' % itn, '%.3f' % acc, '%.3f' % pos_acc
random.shuffle(paragraphs) random.shuffle(gold_sents)
nlp.parser.model.end_training() nlp.parser.model.end_training()
nlp.tagger.model.end_training() nlp.tagger.model.end_training()
return acc return acc
@ -257,10 +258,16 @@ def evaluate(Language, dev_loc, model_dir, gold_preproc=False):
return float(uas_corr) / (total + loss) return float(uas_corr) / (total + loss)
def main(train_loc, dev_loc, model_dir): @plac.annotations(
train(English, read_docparse_gold(train_loc), model_dir, train_loc=("Training file location",),
dev_loc=("Dev. file location",),
model_dir=("Location of output model directory",),
n_sents=("Number of training sentences", "option", "n", int)
)
def main(train_loc, dev_loc, model_dir, n_sents=0):
train(English, read_gold(train_loc, n=n_sents), model_dir,
gold_preproc=False, force_gold=False) gold_preproc=False, force_gold=False)
print evaluate(English, read_docparse_gold(dev_loc), model_dir, gold_preproc=False) print evaluate(English, read_gold(dev_loc), model_dir, gold_preproc=False)
if __name__ == '__main__': if __name__ == '__main__':