diff --git a/examples/training/train_ner.py b/examples/training/train_ner.py index 8bb01b87f..797dbcb9c 100644 --- a/examples/training/train_ner.py +++ b/examples/training/train_ner.py @@ -56,7 +56,10 @@ def main(model=None, output_dir=None, n_iter=100): # get names of other pipes to disable them during training other_pipes = [pipe for pipe in nlp.pipe_names if pipe != "ner"] with nlp.disable_pipes(*other_pipes): # only train NER - optimizer = nlp.begin_training() + # reset and initialize the weights randomly – but only if we're + # training a new model + if model is None: + nlp.begin_training() for itn in range(n_iter): random.shuffle(TRAIN_DATA) losses = {} @@ -68,7 +71,6 @@ def main(model=None, output_dir=None, n_iter=100): texts, # batch of texts annotations, # batch of annotations drop=0.5, # dropout - make it harder to memorise data - sgd=optimizer, # callable to update weights losses=losses, ) print("Losses", losses)