diff --git a/examples/training/train_textcat.py b/examples/training/train_textcat.py index bd9e5ee18..c678632cd 100644 --- a/examples/training/train_textcat.py +++ b/examples/training/train_textcat.py @@ -24,8 +24,9 @@ from spacy.util import minibatch, compounding output_dir=("Optional output directory", "option", "o", Path), n_texts=("Number of texts to train from", "option", "t", int), n_iter=("Number of training iterations", "option", "n", int), + init_tok2vec=("Pretrained tok2vec weights", "option", "t2v", Path) ) -def main(model=None, output_dir=None, n_iter=20, n_texts=2000): +def main(model=None, output_dir=None, n_iter=20, n_texts=2000, init_tok2vec=None): if output_dir is not None: output_dir = Path(output_dir) if not output_dir.exists(): @@ -67,6 +68,9 @@ def main(model=None, output_dir=None, n_iter=20, n_texts=2000): other_pipes = [pipe for pipe in nlp.pipe_names if pipe != "textcat"] with nlp.disable_pipes(*other_pipes): # only train textcat optimizer = nlp.begin_training() + if init_tok2vec is not None: + with init_tok2vec.open("rb") as file_: + textcat.model.tok2vec.from_bytes(file_.read()) print("Training the model...") print("{:^5}\t{:^5}\t{:^5}\t{:^5}".format("LOSS", "P", "R", "F")) for i in range(n_iter):