Update train.py

This commit is contained in:
Matthew Honnibal 2016-10-13 03:23:48 +02:00
parent 41f88ce938
commit bd1bfcca61

View File

@ -17,6 +17,7 @@ import spacy.util
from spacy.syntax.util import Config from spacy.syntax.util import Config
from spacy.gold import read_json_file from spacy.gold import read_json_file
from spacy.gold import GoldParse from spacy.gold import GoldParse
from spacy.gold import merge_sents
from spacy.scorer import Scorer from spacy.scorer import Scorer
@ -63,22 +64,6 @@ def score_model(scorer, nlp, raw_text, annot_tuples, verbose=False):
scorer.score(tokens, gold, verbose=verbose) scorer.score(tokens, gold, verbose=verbose)
def _merge_sents(sents):
m_deps = [[], [], [], [], [], []]
m_brackets = []
i = 0
for (ids, words, tags, heads, labels, ner), brackets in sents:
m_deps[0].extend(id_ + i for id_ in ids)
m_deps[1].extend(words)
m_deps[2].extend(tags)
m_deps[3].extend(head + i for head in heads)
m_deps[4].extend(labels)
m_deps[5].extend(ner)
m_brackets.extend((b['first'] + i, b['last'] + i, b['label']) for b in brackets)
i += len(ids)
return [(m_deps, m_brackets)]
def train(Language, train_data, dev_data, model_dir, tagger_cfg, parser_cfg, entity_cfg, def train(Language, train_data, dev_data, model_dir, tagger_cfg, parser_cfg, entity_cfg,
n_iter=15, seed=0, gold_preproc=False, n_sents=0, corruption_level=0): n_iter=15, seed=0, gold_preproc=False, n_sents=0, corruption_level=0):
print("Itn.\tP.Loss\tUAS\tNER F.\tTag %\tToken %") print("Itn.\tP.Loss\tUAS\tNER F.\tTag %\tToken %")
@ -86,10 +71,11 @@ def train(Language, train_data, dev_data, model_dir, tagger_cfg, parser_cfg, ent
with Language.train(model_dir, train_data, with Language.train(model_dir, train_data,
tagger_cfg, parser_cfg, entity_cfg) as trainer: tagger_cfg, parser_cfg, entity_cfg) as trainer:
loss = 0 loss = 0
for itn, epoch in enumerate(trainer.epochs(n_iter, augment_data=None)): for itn, epoch in enumerate(trainer.epochs(n_iter, gold_preproc=gold_preproc,
augment_data=None)):
for doc, gold in epoch: for doc, gold in epoch:
trainer.update(doc, gold) trainer.update(doc, gold)
dev_scores = trainer.evaluate(dev_data) dev_scores = trainer.evaluate(dev_data, gold_preproc=gold_preproc)
print(format_str.format(itn, loss, **dev_scores.scores)) print(format_str.format(itn, loss, **dev_scores.scores))
@ -105,7 +91,7 @@ def evaluate(Language, gold_tuples, model_dir, gold_preproc=False, verbose=False
if gold_preproc: if gold_preproc:
raw_text = None raw_text = None
else: else:
sents = _merge_sents(sents) sents = merge_sents(sents)
for annot_tuples, brackets in sents: for annot_tuples, brackets in sents:
if raw_text is None: if raw_text is None:
tokens = nlp.tokenizer.tokens_from_list(annot_tuples[1]) tokens = nlp.tokenizer.tokens_from_list(annot_tuples[1])