From 0035fd9efe9e459f75b52a9621a8b30bd39e2e11 Mon Sep 17 00:00:00 2001 From: ines Date: Thu, 23 Mar 2017 11:08:41 +0100 Subject: [PATCH] Add spacy train work in progress --- spacy/__main__.py | 32 +++++++++++++- spacy/cli/__init__.py | 1 + spacy/cli/train.py | 98 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 129 insertions(+), 2 deletions(-) create mode 100644 spacy/cli/train.py diff --git a/spacy/__main__.py b/spacy/__main__.py index f77a504e4..36dbb93d6 100644 --- a/spacy/__main__.py +++ b/spacy/__main__.py @@ -8,12 +8,14 @@ from spacy.cli import download as cli_download from spacy.cli import link as cli_link from spacy.cli import info as cli_info from spacy.cli import package as cli_package +from spacy.cli import train as cli_train +from spacy.cli import train_config as cli_train_config class CLI(object): """Command-line interface for spaCy""" - commands = ('download', 'link', 'info', 'package') + commands = ('download', 'link', 'info', 'package', 'train', 'train_config') @plac.annotations( model=("model to download (shortcut or model name)", "positional", None, str), @@ -61,7 +63,7 @@ class CLI(object): @plac.annotations( input_dir=("directory with model data", "positional", None, str), - output_dir=("output directory", "positional", None, str), + output_dir=("output parent directory", "positional", None, str), force=("force overwriting of existing folder in output directory", "flag", "f", bool) ) def package(self, input_dir, output_dir, force=False): @@ -74,6 +76,32 @@ class CLI(object): cli_package(input_dir, output_dir, force) + @plac.annotations( + lang=("language", "positional", None, str), + output_dir=("output directory", "positional", None, str), + train_data=("training data", "positional", None, str), + dev_data=("development data", "positional", None, str), + n_iter=("number of iterations", "flag", "n", int), + tagger=("train tagger", "flag", "t", bool), + parser=("train parser", "flag", "p", bool), + ner=("train NER", "flag", "n", bool) + ) + def train(self, lang, output_dir, train_data, dev_data, n_iter=15, tagger=True, + parser=True, ner=True): + """Train a model.""" + + cli_train(output_dir, train_data, dev_data, tagger, parser, ner) + + + @plac.annotations( + config=("config", "positional", None, str), + ) + def train_config(self, config): + """Train a model from config file.""" + + cli_train_config(config) + + def __missing__(self, name): print("\n Command %r does not exist\n" % name) diff --git a/spacy/cli/__init__.py b/spacy/cli/__init__.py index 2383e04b9..a4bc57ea9 100644 --- a/spacy/cli/__init__.py +++ b/spacy/cli/__init__.py @@ -2,3 +2,4 @@ from .download import download from .info import info from .link import link from .package import package +from .train import train, train_config diff --git a/spacy/cli/train.py b/spacy/cli/train.py new file mode 100644 index 000000000..58c30baf2 --- /dev/null +++ b/spacy/cli/train.py @@ -0,0 +1,98 @@ +# coding: utf8 +from __future__ import unicode_literals, division, print_function + + +import json +from pathlib import Path + +from ..scorer import Scorer +from ..tagger import Tagger +from ..syntax.parser import Parser +from ..gold import GoldParse, merge_sents +from ..gold import read_json_file as read_gold_json +from .. import util + + +def train(language, output_dir, train_data, dev_data, n_iter, tagger, parser, ner): + output_path = Path(output_dir) + train_path = Path(train_data) + dev_path = Path(dev_data) + check_dirs(output_path, data_path, dev_path) + + lang = util.get_lang_class(language) + parser_cfg = dict(locals()) + tagger_cfg = dict(locals()) + entity_cfg = dict(locals()) + parser_cfg['features'] = lang.Defaults.parser_features + entity_cfg['features'] = lang.Defaults.entity_features + gold_train = list(read_gold_json(train_path)) + gold_dev = list(read_gold_json(dev_path)) + + train_model(lang, gold_train, gold_dev, output_path, tagger_cfg, parser_cfg, + entity_cfg, n_iter) + scorer = evaluate(lang, list(read_gold_json(dev_loc)), output_path) + print_results(scorer) + + +def train_config(config): + config_path = Path(config) + if not config_path.is_file(): + util.sys_exit(config_path.as_posix(), title="Config file not found") + config = json.load(config_path) + for setting in []: + if setting not in config.keys(): + util.sys_exit("{s} not found in config file.".format(s=setting), + title="Missing setting") + + +def train_model(Language, train_data, dev_data, output_path, tagger_cfg, parser_cfg, + entity_cfg, n_iter): + print("Itn.\tN weight\tN feats\tUAS\tNER F.\tTag %\tToken %") + + with Language.train(output_path, train_data, tagger_cfg, parser_cfg, entity_cfg) as trainer: + loss = 0 + for itn, epoch in enumerate(trainer.epochs(n_iter, augment_data=None)): + for doc, gold in epoch: + trainer.update(doc, gold) + dev_scores = trainer.evaluate(dev_data) + print_progress(itn, trainer.nlp.parser.model.nr_weight, + trainer.nlp.parser.model.nr_active_feat, + **dev_scores.scores) + + +def evaluate(Language, gold_tuples, output_path): + print("Load parser", output_path) + nlp = Language(path=output_path) + scorer = Scorer() + for raw_text, sents in gold_tuples: + sents = merge_sents(sents) + for annot_tuples, brackets in sents: + if raw_text is None: + tokens = nlp.tokenizer.tokens_from_list(annot_tuples[1]) + nlp.tagger(tokens) + nlp.parser(tokens) + nlp.entity(tokens) + else: + tokens = nlp(raw_text) + gold = GoldParse.from_annot_tuples(tokens, annot_tuples) + scorer.score(tokens, gold) + return scorer + + +def check_dirs(input_path, train_path, dev_path): + if not output_path.exists(): + util.sys_exit(output_path.as_posix(), title="Output directory not found") + if not train_path.exists() and train_path.is_file(): + util.sys_exit(train_path.as_posix(), title="Training data not found") + + +def print_progress(itn, nr_weight, nr_active_feat, **scores): + tpl = '{:d}\t{:d}\t{:d}\t{uas:.3f}\t{ents_f:.3f}\t{tags_acc:.3f}\t{token_acc:.3f}' + print(tpl.format(itn, nr_weight, nr_active_feat, **scores)) + + +def print_results(scorer): + results = {'TOK': scorer.token_acc, 'POS': scorer.tags_acc, 'UAS': scorer.uas, + 'LAS': scorer.las, 'NER P': scorer.ents_p, 'NER R': scorer.ents_r, + 'NER F': scorer.ents_f} + util.print_table(results, title="Results")