Fix init-model for npz vectors

This commit is contained in:
Matthew Honnibal 2018-07-04 02:29:48 +02:00
parent 59d655e8d0
commit dee8bdb900

View File

@ -66,14 +66,14 @@ def init_model(lang, output_dir, freqs_loc=None, clusters_loc=None, jsonl_loc=No
if freqs_loc is not None and not freqs_loc.exists():
prints(freqs_loc, title=Messages.M037, exits=1)
lex_attrs = read_attrs_from_deprecated(freqs_loc, clusters_loc)
vectors_loc = ensure_path(vectors_loc)
if vectors_loc and vectors_loc.parts[-1].endswith('.npz'):
vectors_data = numpy.load(vectors_loc.open('rb'))
vector_keys = [lex['orth'] for lex in lex_attrs
if 'id' in lex and lex['id'] < vectors_data.shape[0]]
else:
vectors_data, vector_keys = read_vectors(vectors_loc) if vectors_loc else (None, None)
nlp = create_model(lang, lex_attrs, vectors_data, vector_keys, prune_vectors)
nlp = create_model(lang, lex_attrs)
if vectors_loc is not None:
add_vectors(nlp, vectors_loc, prune_vectors)
vec_added = len(nlp.vocab.vectors)
lex_added = len(nlp.vocab)
prints(Messages.M039.format(entries=lex_added, vectors=vec_added),
title=Messages.M038)
if not output_dir.exists():
output_dir.mkdir()
nlp.to_disk(output_dir)
@ -112,7 +112,7 @@ def read_attrs_from_deprecated(freqs_loc, clusters_loc):
return lex_attrs
def create_model(lang, lex_attrs, vectors_data, vector_keys, prune_vectors):
def create_model(lang, lex_attrs):
print("Creating model...")
lang_class = get_lang_class(lang)
nlp = lang_class()
@ -120,28 +120,38 @@ def create_model(lang, lex_attrs, vectors_data, vector_keys, prune_vectors):
lexeme.rank = 0
lex_added = 0
for attrs in lex_attrs:
if 'settings' in attrs:
continue
lexeme = nlp.vocab[attrs['orth']]
lexeme.set_attrs(**intify_attrs(attrs))
lexeme.set_attrs(**attrs)
lexeme.is_oov = False
lex_added += 1
lex_added += 1
oov_prob = min(lex.prob for lex in nlp.vocab)
nlp.vocab.cfg.update({'oov_prob': oov_prob-1})
if vector_keys is not None:
for word in vector_keys:
if word not in nlp.vocab:
lexeme = nlp.vocab[word]
lexeme.is_oov = False
lex_added += 1
if vectors_data is not None:
nlp.vocab.vectors = Vectors(data=vectors_data, keys=vector_keys)
if prune_vectors >= 1:
nlp.vocab.prune_vectors(prune_vectors)
vec_added = len(nlp.vocab.vectors)
prints(Messages.M039.format(entries=lex_added, vectors=vec_added),
title=Messages.M038)
return nlp
def add_vectors(nlp, vectors_loc, prune_vectors):
vectors_loc = ensure_path(vectors_loc)
if vectors_loc and vectors_loc.parts[-1].endswith('.npz'):
nlp.vocab.vectors = Vectors(data=numpy.load(vectors_loc.open('rb')))
for lex in nlp.vocab:
if lex.rank:
nlp.vocab.vectors.add(lex.orth, row=lex.rank)
else:
vectors_data, vector_keys = read_vectors(vectors_loc) if vectors_loc else (None, None)
if vector_keys is not None:
for word in vector_keys:
if word not in nlp.vocab:
lexeme = nlp.vocab[word]
lexeme.is_oov = False
lex_added += 1
if vectors_data is not None:
nlp.vocab.vectors = Vectors(data=vectors_data, keys=vector_keys)
nlp.vocab.vectors.name = '%s_model.vectors' % nlp.meta['lang']
nlp.meta['vectors']['name'] = nlp.vocab.vectors.name
if prune_vectors >= 1:
nlp.vocab.prune_vectors(prune_vectors)
def read_vectors(vectors_loc):
print("Reading vectors from %s" % vectors_loc)