diff --git a/spacy/cli/init_model.py b/spacy/cli/init_model.py
index e685377a9..0a9879213 100644
--- a/spacy/cli/init_model.py
+++ b/spacy/cli/init_model.py
@@ -66,14 +66,14 @@ def init_model(lang, output_dir, freqs_loc=None, clusters_loc=None, jsonl_loc=No
     if freqs_loc is not None and not freqs_loc.exists():
         prints(freqs_loc, title=Messages.M037, exits=1)
     lex_attrs = read_attrs_from_deprecated(freqs_loc, clusters_loc)
-    vectors_loc = ensure_path(vectors_loc)
-    if vectors_loc and vectors_loc.parts[-1].endswith('.npz'):
-        vectors_data = numpy.load(vectors_loc.open('rb'))
-        vector_keys = [lex['orth'] for lex in lex_attrs
-                       if 'id' in lex and lex['id'] < vectors_data.shape[0]]
-    else:
-        vectors_data, vector_keys = read_vectors(vectors_loc) if vectors_loc else (None, None)
-    nlp = create_model(lang, lex_attrs, vectors_data, vector_keys, prune_vectors)
+
+    nlp = create_model(lang, lex_attrs)
+    if vectors_loc is not None:
+        add_vectors(nlp, vectors_loc, prune_vectors)
+    vec_added = len(nlp.vocab.vectors)
+    lex_added = len(nlp.vocab)
+    prints(Messages.M039.format(entries=lex_added, vectors=vec_added),
+           title=Messages.M038)
     if not output_dir.exists():
         output_dir.mkdir()
     nlp.to_disk(output_dir)
@@ -112,7 +112,7 @@ def read_attrs_from_deprecated(freqs_loc, clusters_loc):
     return lex_attrs


-def create_model(lang, lex_attrs, vectors_data, vector_keys, prune_vectors):
+def create_model(lang, lex_attrs):
     print("Creating model...")
     lang_class = get_lang_class(lang)
     nlp = lang_class()
@@ -120,28 +120,38 @@
         lexeme.rank = 0
     lex_added = 0
     for attrs in lex_attrs:
+        if 'settings' in attrs:
+            continue
         lexeme = nlp.vocab[attrs['orth']]
-        lexeme.set_attrs(**intify_attrs(attrs))
+        lexeme.set_attrs(**attrs)
         lexeme.is_oov = False
         lex_added += 1
         lex_added += 1
     oov_prob = min(lex.prob for lex in nlp.vocab)
     nlp.vocab.cfg.update({'oov_prob': oov_prob-1})
-    if vector_keys is not None:
-        for word in vector_keys:
-            if word not in nlp.vocab:
-                lexeme = nlp.vocab[word]
-                lexeme.is_oov = False
-                lex_added += 1
-    if vectors_data is not None:
-        nlp.vocab.vectors = Vectors(data=vectors_data, keys=vector_keys)
-    if prune_vectors >= 1:
-        nlp.vocab.prune_vectors(prune_vectors)
-    vec_added = len(nlp.vocab.vectors)
-    prints(Messages.M039.format(entries=lex_added, vectors=vec_added),
-           title=Messages.M038)
     return nlp


+def add_vectors(nlp, vectors_loc, prune_vectors):
+    vectors_loc = ensure_path(vectors_loc)
+    if vectors_loc and vectors_loc.parts[-1].endswith('.npz'):
+        nlp.vocab.vectors = Vectors(data=numpy.load(vectors_loc.open('rb')))
+        for lex in nlp.vocab:
+            if lex.rank:
+                nlp.vocab.vectors.add(lex.orth, row=lex.rank)
+    else:
+        vectors_data, vector_keys = read_vectors(vectors_loc) if vectors_loc else (None, None)
+        if vector_keys is not None:
+            for word in vector_keys:
+                if word not in nlp.vocab:
+                    lexeme = nlp.vocab[word]
+                    lexeme.is_oov = False
+                    lex_added += 1
+        if vectors_data is not None:
+            nlp.vocab.vectors = Vectors(data=vectors_data, keys=vector_keys)
+    nlp.vocab.vectors.name = '%s_model.vectors' % nlp.meta['lang']
+    nlp.meta['vectors']['name'] = nlp.vocab.vectors.name
+    if prune_vectors >= 1:
+        nlp.vocab.prune_vectors(prune_vectors)
 def read_vectors(vectors_loc):
     print("Reading vectors from %s" % vectors_loc)
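
For context, the vector-attachment step that this patch factors out into add_vectors() can be sketched in isolation with spaCy's public Vectors API. The snippet below is an illustrative sketch only, not part of the patch; the toy keys and the zero-valued matrix stand in for whatever read_vectors() would return from a real vectors file.

    import numpy
    import spacy
    from spacy.vectors import Vectors

    nlp = spacy.blank('en')

    # Stand-ins for the (data, keys) pair produced by read_vectors()
    vector_keys = ['apple', 'orange']
    vectors_data = numpy.zeros((2, 4), dtype='f')

    # Attach the table and name it after the language, mirroring add_vectors()
    nlp.vocab.vectors = Vectors(data=vectors_data, keys=vector_keys)
    nlp.vocab.vectors.name = '%s_model.vectors' % nlp.meta['lang']

    print(len(nlp.vocab), len(nlp.vocab.vectors))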