Add spacy train work in progress

This commit is contained in:
ines 2017-03-23 11:08:41 +01:00
parent d5ebf583a4
commit 0035fd9efe
3 changed files with 129 additions and 2 deletions

View File

@ -8,12 +8,14 @@ from spacy.cli import download as cli_download
from spacy.cli import link as cli_link from spacy.cli import link as cli_link
from spacy.cli import info as cli_info from spacy.cli import info as cli_info
from spacy.cli import package as cli_package from spacy.cli import package as cli_package
from spacy.cli import train as cli_train
from spacy.cli import train_config as cli_train_config
class CLI(object): class CLI(object):
"""Command-line interface for spaCy""" """Command-line interface for spaCy"""
commands = ('download', 'link', 'info', 'package') commands = ('download', 'link', 'info', 'package', 'train', 'train_config')
@plac.annotations( @plac.annotations(
model=("model to download (shortcut or model name)", "positional", None, str), model=("model to download (shortcut or model name)", "positional", None, str),
@ -61,7 +63,7 @@ class CLI(object):
@plac.annotations( @plac.annotations(
input_dir=("directory with model data", "positional", None, str), input_dir=("directory with model data", "positional", None, str),
output_dir=("output directory", "positional", None, str), output_dir=("output parent directory", "positional", None, str),
force=("force overwriting of existing folder in output directory", "flag", "f", bool) force=("force overwriting of existing folder in output directory", "flag", "f", bool)
) )
def package(self, input_dir, output_dir, force=False): def package(self, input_dir, output_dir, force=False):
@ -74,6 +76,32 @@ class CLI(object):
cli_package(input_dir, output_dir, force) cli_package(input_dir, output_dir, force)
@plac.annotations(
lang=("language", "positional", None, str),
output_dir=("output directory", "positional", None, str),
train_data=("training data", "positional", None, str),
dev_data=("development data", "positional", None, str),
n_iter=("number of iterations", "flag", "n", int),
tagger=("train tagger", "flag", "t", bool),
parser=("train parser", "flag", "p", bool),
ner=("train NER", "flag", "n", bool)
)
def train(self, lang, output_dir, train_data, dev_data, n_iter=15, tagger=True,
parser=True, ner=True):
"""Train a model."""
cli_train(output_dir, train_data, dev_data, tagger, parser, ner)
@plac.annotations(
config=("config", "positional", None, str),
)
def train_config(self, config):
"""Train a model from config file."""
cli_train_config(config)
def __missing__(self, name): def __missing__(self, name):
print("\n Command %r does not exist\n" % name) print("\n Command %r does not exist\n" % name)

View File

@ -2,3 +2,4 @@ from .download import download
from .info import info from .info import info
from .link import link from .link import link
from .package import package from .package import package
from .train import train, train_config

98
spacy/cli/train.py Normal file
View File

@ -0,0 +1,98 @@
# coding: utf8
from __future__ import unicode_literals, division, print_function
import json
from pathlib import Path
from ..scorer import Scorer
from ..tagger import Tagger
from ..syntax.parser import Parser
from ..gold import GoldParse, merge_sents
from ..gold import read_json_file as read_gold_json
from .. import util
def train(language, output_dir, train_data, dev_data, n_iter, tagger, parser, ner):
output_path = Path(output_dir)
train_path = Path(train_data)
dev_path = Path(dev_data)
check_dirs(output_path, data_path, dev_path)
lang = util.get_lang_class(language)
parser_cfg = dict(locals())
tagger_cfg = dict(locals())
entity_cfg = dict(locals())
parser_cfg['features'] = lang.Defaults.parser_features
entity_cfg['features'] = lang.Defaults.entity_features
gold_train = list(read_gold_json(train_path))
gold_dev = list(read_gold_json(dev_path))
train_model(lang, gold_train, gold_dev, output_path, tagger_cfg, parser_cfg,
entity_cfg, n_iter)
scorer = evaluate(lang, list(read_gold_json(dev_loc)), output_path)
print_results(scorer)
def train_config(config):
config_path = Path(config)
if not config_path.is_file():
util.sys_exit(config_path.as_posix(), title="Config file not found")
config = json.load(config_path)
for setting in []:
if setting not in config.keys():
util.sys_exit("{s} not found in config file.".format(s=setting),
title="Missing setting")
def train_model(Language, train_data, dev_data, output_path, tagger_cfg, parser_cfg,
entity_cfg, n_iter):
print("Itn.\tN weight\tN feats\tUAS\tNER F.\tTag %\tToken %")
with Language.train(output_path, train_data, tagger_cfg, parser_cfg, entity_cfg) as trainer:
loss = 0
for itn, epoch in enumerate(trainer.epochs(n_iter, augment_data=None)):
for doc, gold in epoch:
trainer.update(doc, gold)
dev_scores = trainer.evaluate(dev_data)
print_progress(itn, trainer.nlp.parser.model.nr_weight,
trainer.nlp.parser.model.nr_active_feat,
**dev_scores.scores)
def evaluate(Language, gold_tuples, output_path):
print("Load parser", output_path)
nlp = Language(path=output_path)
scorer = Scorer()
for raw_text, sents in gold_tuples:
sents = merge_sents(sents)
for annot_tuples, brackets in sents:
if raw_text is None:
tokens = nlp.tokenizer.tokens_from_list(annot_tuples[1])
nlp.tagger(tokens)
nlp.parser(tokens)
nlp.entity(tokens)
else:
tokens = nlp(raw_text)
gold = GoldParse.from_annot_tuples(tokens, annot_tuples)
scorer.score(tokens, gold)
return scorer
def check_dirs(input_path, train_path, dev_path):
if not output_path.exists():
util.sys_exit(output_path.as_posix(), title="Output directory not found")
if not train_path.exists() and train_path.is_file():
util.sys_exit(train_path.as_posix(), title="Training data not found")
def print_progress(itn, nr_weight, nr_active_feat, **scores):
tpl = '{:d}\t{:d}\t{:d}\t{uas:.3f}\t{ents_f:.3f}\t{tags_acc:.3f}\t{token_acc:.3f}'
print(tpl.format(itn, nr_weight, nr_active_feat, **scores))
def print_results(scorer):
results = {'TOK': scorer.token_acc, 'POS': scorer.tags_acc, 'UAS': scorer.uas,
'LAS': scorer.las, 'NER P': scorer.ents_p, 'NER R': scorer.ents_r,
'NER F': scorer.ents_f}
util.print_table(results, title="Results")