Fix begin_training if get_gold_tuples is None

This commit is contained in:
ines 2017-11-01 13:14:31 +01:00
parent affd3404ab
commit bfe17b7df1
6 changed files with 8 additions and 6 deletions

View File

@ -94,7 +94,7 @@ def main(model=None, output_dir=None, n_iter=100):
other_pipes = [pipe for pipe in nlp.pipe_names if pipe != 'parser']
with nlp.disable_pipes(*other_pipes): # only train parser
optimizer = nlp.begin_training(lambda: [])
optimizer = nlp.begin_training()
for itn in range(n_iter):
random.shuffle(TRAIN_DATA)
losses = {}

View File

@ -87,7 +87,7 @@ def main(model=None, new_model_name='animal', output_dir=None, n_iter=50):
other_pipes = [pipe for pipe in nlp.pipe_names if pipe != 'ner']
with nlp.disable_pipes(*other_pipes): # only train NER
random.seed(0)
optimizer = nlp.begin_training(lambda: [])
optimizer = nlp.begin_training()
for itn in range(n_iter):
losses = {}
gold_parses = get_gold_parses(nlp.make_doc, TRAIN_DATA)

View File

@ -64,7 +64,7 @@ def main(model=None, output_dir=None, n_iter=1000):
# get names of other pipes to disable them during training
other_pipes = [pipe for pipe in nlp.pipe_names if pipe != 'parser']
with nlp.disable_pipes(*other_pipes): # only train parser
optimizer = nlp.begin_training(lambda: [])
optimizer = nlp.begin_training()
for itn in range(n_iter):
random.shuffle(TRAIN_DATA)
losses = {}

View File

@ -61,7 +61,7 @@ def main(lang='en', output_dir=None, n_iter=25):
tagger = nlp.create_pipe('tagger')
nlp.add_pipe(tagger)
optimizer = nlp.begin_training(lambda: [])
optimizer = nlp.begin_training()
for i in range(n_iter):
random.shuffle(TRAIN_DATA)
losses = {}

View File

@ -59,7 +59,7 @@ def main(model=None, output_dir=None, n_iter=20):
# get names of other pipes to disable them during training
other_pipes = [pipe for pipe in nlp.pipe_names if pipe != 'textcat']
with nlp.disable_pipes(*other_pipes): # only train textcat
optimizer = nlp.begin_training(lambda: [])
optimizer = nlp.begin_training()
print("Training the model...")
print('{:^5}\t{:^5}\t{:^5}\t{:^5}'.format('LOSS', 'P', 'R', 'F'))
for i in range(n_iter):

View File

@ -436,8 +436,10 @@ class Language(object):
**cfg: Config parameters.
RETURNS: An optimizer
"""
if get_gold_tuples is None:
get_gold_tuples = lambda: []
# Populate vocab
if get_gold_tuples is not None:
else:
for _, annots_brackets in get_gold_tuples():
for annots, _ in annots_brackets:
for word in annots[1]: