mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-24 17:06:29 +03:00
Add spacy train work in progress
This commit is contained in:
parent
d5ebf583a4
commit
0035fd9efe
|
@ -8,12 +8,14 @@ from spacy.cli import download as cli_download
|
||||||
from spacy.cli import link as cli_link
|
from spacy.cli import link as cli_link
|
||||||
from spacy.cli import info as cli_info
|
from spacy.cli import info as cli_info
|
||||||
from spacy.cli import package as cli_package
|
from spacy.cli import package as cli_package
|
||||||
|
from spacy.cli import train as cli_train
|
||||||
|
from spacy.cli import train_config as cli_train_config
|
||||||
|
|
||||||
|
|
||||||
class CLI(object):
|
class CLI(object):
|
||||||
"""Command-line interface for spaCy"""
|
"""Command-line interface for spaCy"""
|
||||||
|
|
||||||
commands = ('download', 'link', 'info', 'package')
|
commands = ('download', 'link', 'info', 'package', 'train', 'train_config')
|
||||||
|
|
||||||
@plac.annotations(
|
@plac.annotations(
|
||||||
model=("model to download (shortcut or model name)", "positional", None, str),
|
model=("model to download (shortcut or model name)", "positional", None, str),
|
||||||
|
@ -61,7 +63,7 @@ class CLI(object):
|
||||||
|
|
||||||
@plac.annotations(
|
@plac.annotations(
|
||||||
input_dir=("directory with model data", "positional", None, str),
|
input_dir=("directory with model data", "positional", None, str),
|
||||||
output_dir=("output directory", "positional", None, str),
|
output_dir=("output parent directory", "positional", None, str),
|
||||||
force=("force overwriting of existing folder in output directory", "flag", "f", bool)
|
force=("force overwriting of existing folder in output directory", "flag", "f", bool)
|
||||||
)
|
)
|
||||||
def package(self, input_dir, output_dir, force=False):
|
def package(self, input_dir, output_dir, force=False):
|
||||||
|
@ -74,6 +76,32 @@ class CLI(object):
|
||||||
cli_package(input_dir, output_dir, force)
|
cli_package(input_dir, output_dir, force)
|
||||||
|
|
||||||
|
|
||||||
|
@plac.annotations(
|
||||||
|
lang=("language", "positional", None, str),
|
||||||
|
output_dir=("output directory", "positional", None, str),
|
||||||
|
train_data=("training data", "positional", None, str),
|
||||||
|
dev_data=("development data", "positional", None, str),
|
||||||
|
n_iter=("number of iterations", "flag", "n", int),
|
||||||
|
tagger=("train tagger", "flag", "t", bool),
|
||||||
|
parser=("train parser", "flag", "p", bool),
|
||||||
|
ner=("train NER", "flag", "n", bool)
|
||||||
|
)
|
||||||
|
def train(self, lang, output_dir, train_data, dev_data, n_iter=15, tagger=True,
|
||||||
|
parser=True, ner=True):
|
||||||
|
"""Train a model."""
|
||||||
|
|
||||||
|
cli_train(output_dir, train_data, dev_data, tagger, parser, ner)
|
||||||
|
|
||||||
|
|
||||||
|
@plac.annotations(
|
||||||
|
config=("config", "positional", None, str),
|
||||||
|
)
|
||||||
|
def train_config(self, config):
|
||||||
|
"""Train a model from config file."""
|
||||||
|
|
||||||
|
cli_train_config(config)
|
||||||
|
|
||||||
|
|
||||||
def __missing__(self, name):
|
def __missing__(self, name):
|
||||||
print("\n Command %r does not exist\n" % name)
|
print("\n Command %r does not exist\n" % name)
|
||||||
|
|
||||||
|
|
|
@ -2,3 +2,4 @@ from .download import download
|
||||||
from .info import info
|
from .info import info
|
||||||
from .link import link
|
from .link import link
|
||||||
from .package import package
|
from .package import package
|
||||||
|
from .train import train, train_config
|
||||||
|
|
98
spacy/cli/train.py
Normal file
98
spacy/cli/train.py
Normal file
|
@ -0,0 +1,98 @@
|
||||||
|
# coding: utf8
|
||||||
|
from __future__ import unicode_literals, division, print_function
|
||||||
|
|
||||||
|
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from ..scorer import Scorer
|
||||||
|
from ..tagger import Tagger
|
||||||
|
from ..syntax.parser import Parser
|
||||||
|
from ..gold import GoldParse, merge_sents
|
||||||
|
from ..gold import read_json_file as read_gold_json
|
||||||
|
from .. import util
|
||||||
|
|
||||||
|
|
||||||
|
def train(language, output_dir, train_data, dev_data, n_iter, tagger, parser, ner):
|
||||||
|
output_path = Path(output_dir)
|
||||||
|
train_path = Path(train_data)
|
||||||
|
dev_path = Path(dev_data)
|
||||||
|
check_dirs(output_path, data_path, dev_path)
|
||||||
|
|
||||||
|
lang = util.get_lang_class(language)
|
||||||
|
parser_cfg = dict(locals())
|
||||||
|
tagger_cfg = dict(locals())
|
||||||
|
entity_cfg = dict(locals())
|
||||||
|
parser_cfg['features'] = lang.Defaults.parser_features
|
||||||
|
entity_cfg['features'] = lang.Defaults.entity_features
|
||||||
|
gold_train = list(read_gold_json(train_path))
|
||||||
|
gold_dev = list(read_gold_json(dev_path))
|
||||||
|
|
||||||
|
train_model(lang, gold_train, gold_dev, output_path, tagger_cfg, parser_cfg,
|
||||||
|
entity_cfg, n_iter)
|
||||||
|
scorer = evaluate(lang, list(read_gold_json(dev_loc)), output_path)
|
||||||
|
print_results(scorer)
|
||||||
|
|
||||||
|
|
||||||
|
def train_config(config):
|
||||||
|
config_path = Path(config)
|
||||||
|
if not config_path.is_file():
|
||||||
|
util.sys_exit(config_path.as_posix(), title="Config file not found")
|
||||||
|
config = json.load(config_path)
|
||||||
|
for setting in []:
|
||||||
|
if setting not in config.keys():
|
||||||
|
util.sys_exit("{s} not found in config file.".format(s=setting),
|
||||||
|
title="Missing setting")
|
||||||
|
|
||||||
|
|
||||||
|
def train_model(Language, train_data, dev_data, output_path, tagger_cfg, parser_cfg,
|
||||||
|
entity_cfg, n_iter):
|
||||||
|
print("Itn.\tN weight\tN feats\tUAS\tNER F.\tTag %\tToken %")
|
||||||
|
|
||||||
|
with Language.train(output_path, train_data, tagger_cfg, parser_cfg, entity_cfg) as trainer:
|
||||||
|
loss = 0
|
||||||
|
for itn, epoch in enumerate(trainer.epochs(n_iter, augment_data=None)):
|
||||||
|
for doc, gold in epoch:
|
||||||
|
trainer.update(doc, gold)
|
||||||
|
dev_scores = trainer.evaluate(dev_data)
|
||||||
|
print_progress(itn, trainer.nlp.parser.model.nr_weight,
|
||||||
|
trainer.nlp.parser.model.nr_active_feat,
|
||||||
|
**dev_scores.scores)
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate(Language, gold_tuples, output_path):
|
||||||
|
print("Load parser", output_path)
|
||||||
|
nlp = Language(path=output_path)
|
||||||
|
scorer = Scorer()
|
||||||
|
for raw_text, sents in gold_tuples:
|
||||||
|
sents = merge_sents(sents)
|
||||||
|
for annot_tuples, brackets in sents:
|
||||||
|
if raw_text is None:
|
||||||
|
tokens = nlp.tokenizer.tokens_from_list(annot_tuples[1])
|
||||||
|
nlp.tagger(tokens)
|
||||||
|
nlp.parser(tokens)
|
||||||
|
nlp.entity(tokens)
|
||||||
|
else:
|
||||||
|
tokens = nlp(raw_text)
|
||||||
|
gold = GoldParse.from_annot_tuples(tokens, annot_tuples)
|
||||||
|
scorer.score(tokens, gold)
|
||||||
|
return scorer
|
||||||
|
|
||||||
|
|
||||||
|
def check_dirs(input_path, train_path, dev_path):
|
||||||
|
if not output_path.exists():
|
||||||
|
util.sys_exit(output_path.as_posix(), title="Output directory not found")
|
||||||
|
if not train_path.exists() and train_path.is_file():
|
||||||
|
util.sys_exit(train_path.as_posix(), title="Training data not found")
|
||||||
|
|
||||||
|
|
||||||
|
def print_progress(itn, nr_weight, nr_active_feat, **scores):
|
||||||
|
tpl = '{:d}\t{:d}\t{:d}\t{uas:.3f}\t{ents_f:.3f}\t{tags_acc:.3f}\t{token_acc:.3f}'
|
||||||
|
print(tpl.format(itn, nr_weight, nr_active_feat, **scores))
|
||||||
|
|
||||||
|
|
||||||
|
def print_results(scorer):
|
||||||
|
results = {'TOK': scorer.token_acc, 'POS': scorer.tags_acc, 'UAS': scorer.uas,
|
||||||
|
'LAS': scorer.las, 'NER P': scorer.ents_p, 'NER R': scorer.ents_r,
|
||||||
|
'NER F': scorer.ents_f}
|
||||||
|
util.print_table(results, title="Results")
|
Loading…
Reference in New Issue
Block a user