Mirror of https://github.com/explosion/spaCy.git (synced 2024-12-24 17:06:29 +03:00)
Add gold_preproc flag to cli/train
This commit is contained in:
parent
42fa84075f
commit
84bb543e4d
|
@ -32,10 +32,12 @@ from ..compat import json_dumps
|
||||||
resume=("Whether to resume training", "flag", "R", bool),
|
resume=("Whether to resume training", "flag", "R", bool),
|
||||||
no_tagger=("Don't train tagger", "flag", "T", bool),
|
no_tagger=("Don't train tagger", "flag", "T", bool),
|
||||||
no_parser=("Don't train parser", "flag", "P", bool),
|
no_parser=("Don't train parser", "flag", "P", bool),
|
||||||
no_entities=("Don't train NER", "flag", "N", bool)
|
no_entities=("Don't train NER", "flag", "N", bool),
|
||||||
|
gold_preproc=("Use gold preprocessing", "flag", "G", bool),
|
||||||
)
|
)
|
||||||
def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0,
|
def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0,
|
||||||
use_gpu=-1, resume=False, no_tagger=False, no_parser=False, no_entities=False):
|
use_gpu=-1, resume=False, no_tagger=False, no_parser=False, no_entities=False,
|
||||||
|
gold_preproc=False):
|
||||||
"""
|
"""
|
||||||
Train a model. Expects data in spaCy's JSON format.
|
Train a model. Expects data in spaCy's JSON format.
|
||||||
"""
|
"""
|
||||||
|
@ -86,7 +88,7 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0,
|
||||||
i += 20
|
i += 20
|
||||||
with tqdm.tqdm(total=n_train_words, leave=False) as pbar:
|
with tqdm.tqdm(total=n_train_words, leave=False) as pbar:
|
||||||
train_docs = corpus.train_docs(nlp, projectivize=True,
|
train_docs = corpus.train_docs(nlp, projectivize=True,
|
||||||
gold_preproc=False, max_length=0)
|
gold_preproc=gold_preproc, max_length=0)
|
||||||
losses = {}
|
losses = {}
|
||||||
for batch in minibatch(train_docs, size=batch_sizes):
|
for batch in minibatch(train_docs, size=batch_sizes):
|
||||||
docs, golds = zip(*batch)
|
docs, golds = zip(*batch)
|
||||||
|
@ -104,7 +106,7 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0,
|
||||||
scorer = nlp_loaded.evaluate(
|
scorer = nlp_loaded.evaluate(
|
||||||
corpus.dev_docs(
|
corpus.dev_docs(
|
||||||
nlp_loaded,
|
nlp_loaded,
|
||||||
gold_preproc=False))
|
gold_preproc=gold_preproc))
|
||||||
acc_loc =(output_path / ('model%d' % i) / 'accuracy.json')
|
acc_loc =(output_path / ('model%d' % i) / 'accuracy.json')
|
||||||
with acc_loc.open('w') as file_:
|
with acc_loc.open('w') as file_:
|
||||||
file_.write(json_dumps(scorer.scores))
|
file_.write(json_dumps(scorer.scores))
|
||||||
|
|
Loading…
Reference in New Issue
Block a user