From 49ef06d793b885c3bd634ac72f38be067246822a Mon Sep 17 00:00:00 2001 From: adrianeboyd Date: Wed, 20 May 2020 18:49:11 +0200 Subject: [PATCH] Add option for base model in init-model CLI (#5467) Intended for languages like Chinese with a custom tokenizer. --- spacy/cli/init_model.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/spacy/cli/init_model.py b/spacy/cli/init_model.py index 3311a5120..537afd10f 100644 --- a/spacy/cli/init_model.py +++ b/spacy/cli/init_model.py @@ -17,7 +17,7 @@ from wasabi import msg from ..vectors import Vectors from ..errors import Errors, Warnings -from ..util import ensure_path, get_lang_class, OOV_RANK +from ..util import ensure_path, get_lang_class, load_model, OOV_RANK try: import ftfy @@ -49,6 +49,7 @@ DEFAULT_OOV_PROB = -20 str, ), model_name=("Optional name for the model meta", "option", "mn", str), + base_model=("Base model (for languages with custom tokenizers)", "option", "b", str), ) def init_model( lang, @@ -61,6 +62,7 @@ def init_model( prune_vectors=-1, vectors_name=None, model_name=None, + base_model=None, ): """ Create a new model from raw data, like word frequencies, Brown clusters @@ -92,7 +94,7 @@ def init_model( lex_attrs = read_attrs_from_deprecated(freqs_loc, clusters_loc) with msg.loading("Creating model..."): - nlp = create_model(lang, lex_attrs, name=model_name) + nlp = create_model(lang, lex_attrs, name=model_name, base_model=base_model) msg.good("Successfully created model") if vectors_loc is not None: add_vectors(nlp, vectors_loc, truncate_vectors, prune_vectors, vectors_name) @@ -152,9 +154,16 @@ def read_attrs_from_deprecated(freqs_loc, clusters_loc): return lex_attrs -def create_model(lang, lex_attrs, name=None): - lang_class = get_lang_class(lang) - nlp = lang_class() +def create_model(lang, lex_attrs, name=None, base_model=None): + if base_model: + nlp = load_model(base_model) + # keep the tokenizer but remove any existing pipeline components due to + # potentially conflicting vectors + for pipe in nlp.pipe_names: + nlp.remove_pipe(pipe) + else: + lang_class = get_lang_class(lang) + nlp = lang_class() for lexeme in nlp.vocab: lexeme.rank = OOV_RANK for attrs in lex_attrs: