diff --git a/spacy/cli/init_model.py b/spacy/cli/init_model.py index 37f862ef2..d0d876aed 100644 --- a/spacy/cli/init_model.py +++ b/spacy/cli/init_model.py @@ -37,7 +37,7 @@ def init_model_cli( clusters_loc: Optional[Path] = Opt(None, "--clusters-loc", "-c", help="Optional location of brown clusters data", exists=True), jsonl_loc: Optional[Path] = Opt(None, "--jsonl-loc", "-j", help="Location of JSONL-formatted attributes file", exists=True), vectors_loc: Optional[Path] = Opt(None, "--vectors-loc", "-v", help="Optional vectors file in Word2Vec format", exists=True), - prune_vectors: int = Opt(-1 , "--prune-vectors", "-V", help="Optional number of vectors to prune to"), + prune_vectors: int = Opt(-1, "--prune-vectors", "-V", help="Optional number of vectors to prune to"), truncate_vectors: int = Opt(0, "--truncate-vectors", "-t", help="Optional number of vectors to truncate to when reading in vectors file"), vectors_name: Optional[str] = Opt(None, "--vectors-name", "-vn", help="Optional name for the word vectors, e.g. en_core_web_lg.vectors"), model_name: Optional[str] = Opt(None, "--model-name", "-mn", help="Optional name for the model meta"), @@ -56,6 +56,7 @@ def init_model_cli( freqs_loc=freqs_loc, clusters_loc=clusters_loc, jsonl_loc=jsonl_loc, + vectors_loc=vectors_loc, prune_vectors=prune_vectors, truncate_vectors=truncate_vectors, vectors_name=vectors_name, @@ -228,7 +229,7 @@ def add_vectors( else: if vectors_loc: with msg.loading(f"Reading vectors from {vectors_loc}"): - vectors_data, vector_keys = read_vectors(msg, vectors_loc) + vectors_data, vector_keys = read_vectors(msg, vectors_loc, truncate_vectors) msg.good(f"Loaded vectors from {vectors_loc}") else: vectors_data, vector_keys = (None, None) @@ -247,7 +248,7 @@ def add_vectors( nlp.vocab.prune_vectors(prune_vectors) -def read_vectors(msg: Printer, vectors_loc: Path, truncate_vectors: int = 0): +def read_vectors(msg: Printer, vectors_loc: Path, truncate_vectors: int): f = open_file(vectors_loc) shape = tuple(int(size) for size in next(f).split()) if truncate_vectors >= 1: