* Tmp commit

This commit is contained in:
Matthew Honnibal 2015-02-23 14:05:04 -05:00
parent 10ed738df2
commit f5f15a1ef2

View File

@@ -21,6 +21,7 @@ from spacy.en.pos import POS_TEMPLATES, POS_TAGS, setup_model_dir
from spacy.syntax.parser import GreedyParser
from spacy.syntax.parser import OracleError
from spacy.syntax.util import Config
from spacy.syntax.conll import GoldParse
def is_punct_label(label):
@@ -184,6 +185,7 @@ def get_labels(sents):
def train(Language, paragraphs, model_dir, n_iter=15, feat_set=u'basic', seed=0,
gold_preproc=False, force_gold=False):
print "Setup model dir"
dep_model_dir = path.join(model_dir, 'deps')
pos_model_dir = path.join(model_dir, 'pos')
if path.exists(dep_model_dir):
@@ -198,7 +200,6 @@ def train(Language, paragraphs, model_dir, n_iter=15, feat_set=u'basic', seed=0,
labels = Language.ParserTransitionSystem.get_labels(gold_sents)
Config.write(dep_model_dir, 'config', features=feat_set, seed=seed,
labels=labels)
nlp = Language()
for itn in range(n_iter):
@@ -206,16 +207,16 @@ def train(Language, paragraphs, model_dir, n_iter=15, feat_set=u'basic', seed=0,
pos_corr = 0
n_tokens = 0
for gold_sent in gold_sents:
tokens = nlp.tokenizer(gold_sent.raw)
gold_sent.align_to_tokens(tokens)
tokens = nlp.tokenizer(gold_sent.raw_text)
gold_sent.align_to_tokens(tokens, nlp.parser.moves.label_ids)
nlp.tagger(tokens)
heads_corr += nlp.parser.train(tokens, gold_sent, force_gold=force_gold)
pos_corr += nlp.tagger.train(tokens, gold_parse.tags)
pos_corr += nlp.tagger.train(tokens, gold_sent.tags)
n_tokens += len(tokens)
acc = float(heads_corr) / n_tokens
pos_acc = float(pos_corr) / n_tokens
print '%d: ' % itn, '%.3f' % acc, '%.3f' % pos_acc
random.shuffle(paragraphs)
random.shuffle(gold_sents)
nlp.parser.model.end_training()
nlp.tagger.model.end_training()
return acc
@@ -257,10 +258,16 @@ def evaluate(Language, dev_loc, model_dir, gold_preproc=False):
return float(uas_corr) / (total + loss)
def main(train_loc, dev_loc, model_dir):
train(English, read_docparse_gold(train_loc), model_dir,
@plac.annotations(
train_loc=("Training file location",),
dev_loc=("Dev. file location",),
model_dir=("Location of output model directory",),
n_sents=("Number of training sentences", "option", "n", int)
)
def main(train_loc, dev_loc, model_dir, n_sents=0):
train(English, read_gold(train_loc, n=n_sents), model_dir,
gold_preproc=False, force_gold=False)
print evaluate(English, read_docparse_gold(dev_loc), model_dir, gold_preproc=False)
print evaluate(English, read_gold(dev_loc), model_dir, gold_preproc=False)
if __name__ == '__main__':