Merge pull request #5687 from svlandeg/bugfix/init-model

Fixing init_model
This commit is contained in:
Ines Montani 2020-07-02 14:10:28 +02:00 committed by GitHub
commit ee8a830248
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -56,6 +56,7 @@ def init_model_cli(
freqs_loc=freqs_loc, freqs_loc=freqs_loc,
clusters_loc=clusters_loc, clusters_loc=clusters_loc,
jsonl_loc=jsonl_loc, jsonl_loc=jsonl_loc,
vectors_loc=vectors_loc,
prune_vectors=prune_vectors, prune_vectors=prune_vectors,
truncate_vectors=truncate_vectors, truncate_vectors=truncate_vectors,
vectors_name=vectors_name, vectors_name=vectors_name,
@ -228,7 +229,7 @@ def add_vectors(
else: else:
if vectors_loc: if vectors_loc:
with msg.loading(f"Reading vectors from {vectors_loc}"): with msg.loading(f"Reading vectors from {vectors_loc}"):
vectors_data, vector_keys = read_vectors(msg, vectors_loc) vectors_data, vector_keys = read_vectors(msg, vectors_loc, truncate_vectors)
msg.good(f"Loaded vectors from {vectors_loc}") msg.good(f"Loaded vectors from {vectors_loc}")
else: else:
vectors_data, vector_keys = (None, None) vectors_data, vector_keys = (None, None)
@ -247,7 +248,7 @@ def add_vectors(
nlp.vocab.prune_vectors(prune_vectors) nlp.vocab.prune_vectors(prune_vectors)
def read_vectors(msg: Printer, vectors_loc: Path, truncate_vectors: int = 0): def read_vectors(msg: Printer, vectors_loc: Path, truncate_vectors: int):
f = open_file(vectors_loc) f = open_file(vectors_loc)
shape = tuple(int(size) for size in next(f).split()) shape = tuple(int(size) for size in next(f).split())
if truncate_vectors >= 1: if truncate_vectors >= 1: