mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 18:26:30 +03:00
Fix init-model for npz vectors
This commit is contained in:
parent
59d655e8d0
commit
dee8bdb900
|
@ -66,14 +66,14 @@ def init_model(lang, output_dir, freqs_loc=None, clusters_loc=None, jsonl_loc=No
|
||||||
if freqs_loc is not None and not freqs_loc.exists():
|
if freqs_loc is not None and not freqs_loc.exists():
|
||||||
prints(freqs_loc, title=Messages.M037, exits=1)
|
prints(freqs_loc, title=Messages.M037, exits=1)
|
||||||
lex_attrs = read_attrs_from_deprecated(freqs_loc, clusters_loc)
|
lex_attrs = read_attrs_from_deprecated(freqs_loc, clusters_loc)
|
||||||
vectors_loc = ensure_path(vectors_loc)
|
|
||||||
if vectors_loc and vectors_loc.parts[-1].endswith('.npz'):
|
nlp = create_model(lang, lex_attrs)
|
||||||
vectors_data = numpy.load(vectors_loc.open('rb'))
|
if vectors_loc is not None:
|
||||||
vector_keys = [lex['orth'] for lex in lex_attrs
|
add_vectors(nlp, vectors_loc, prune_vectors)
|
||||||
if 'id' in lex and lex['id'] < vectors_data.shape[0]]
|
vec_added = len(nlp.vocab.vectors)
|
||||||
else:
|
lex_added = len(nlp.vocab)
|
||||||
vectors_data, vector_keys = read_vectors(vectors_loc) if vectors_loc else (None, None)
|
prints(Messages.M039.format(entries=lex_added, vectors=vec_added),
|
||||||
nlp = create_model(lang, lex_attrs, vectors_data, vector_keys, prune_vectors)
|
title=Messages.M038)
|
||||||
if not output_dir.exists():
|
if not output_dir.exists():
|
||||||
output_dir.mkdir()
|
output_dir.mkdir()
|
||||||
nlp.to_disk(output_dir)
|
nlp.to_disk(output_dir)
|
||||||
|
@ -112,7 +112,7 @@ def read_attrs_from_deprecated(freqs_loc, clusters_loc):
|
||||||
return lex_attrs
|
return lex_attrs
|
||||||
|
|
||||||
|
|
||||||
def create_model(lang, lex_attrs, vectors_data, vector_keys, prune_vectors):
|
def create_model(lang, lex_attrs):
|
||||||
print("Creating model...")
|
print("Creating model...")
|
||||||
lang_class = get_lang_class(lang)
|
lang_class = get_lang_class(lang)
|
||||||
nlp = lang_class()
|
nlp = lang_class()
|
||||||
|
@ -120,28 +120,38 @@ def create_model(lang, lex_attrs, vectors_data, vector_keys, prune_vectors):
|
||||||
lexeme.rank = 0
|
lexeme.rank = 0
|
||||||
lex_added = 0
|
lex_added = 0
|
||||||
for attrs in lex_attrs:
|
for attrs in lex_attrs:
|
||||||
|
if 'settings' in attrs:
|
||||||
|
continue
|
||||||
lexeme = nlp.vocab[attrs['orth']]
|
lexeme = nlp.vocab[attrs['orth']]
|
||||||
lexeme.set_attrs(**intify_attrs(attrs))
|
lexeme.set_attrs(**attrs)
|
||||||
lexeme.is_oov = False
|
lexeme.is_oov = False
|
||||||
lex_added += 1
|
lex_added += 1
|
||||||
lex_added += 1
|
lex_added += 1
|
||||||
oov_prob = min(lex.prob for lex in nlp.vocab)
|
oov_prob = min(lex.prob for lex in nlp.vocab)
|
||||||
nlp.vocab.cfg.update({'oov_prob': oov_prob-1})
|
nlp.vocab.cfg.update({'oov_prob': oov_prob-1})
|
||||||
if vector_keys is not None:
|
|
||||||
for word in vector_keys:
|
|
||||||
if word not in nlp.vocab:
|
|
||||||
lexeme = nlp.vocab[word]
|
|
||||||
lexeme.is_oov = False
|
|
||||||
lex_added += 1
|
|
||||||
if vectors_data is not None:
|
|
||||||
nlp.vocab.vectors = Vectors(data=vectors_data, keys=vector_keys)
|
|
||||||
if prune_vectors >= 1:
|
|
||||||
nlp.vocab.prune_vectors(prune_vectors)
|
|
||||||
vec_added = len(nlp.vocab.vectors)
|
|
||||||
prints(Messages.M039.format(entries=lex_added, vectors=vec_added),
|
|
||||||
title=Messages.M038)
|
|
||||||
return nlp
|
return nlp
|
||||||
|
|
||||||
|
def add_vectors(nlp, vectors_loc, prune_vectors):
|
||||||
|
vectors_loc = ensure_path(vectors_loc)
|
||||||
|
if vectors_loc and vectors_loc.parts[-1].endswith('.npz'):
|
||||||
|
nlp.vocab.vectors = Vectors(data=numpy.load(vectors_loc.open('rb')))
|
||||||
|
for lex in nlp.vocab:
|
||||||
|
if lex.rank:
|
||||||
|
nlp.vocab.vectors.add(lex.orth, row=lex.rank)
|
||||||
|
else:
|
||||||
|
vectors_data, vector_keys = read_vectors(vectors_loc) if vectors_loc else (None, None)
|
||||||
|
if vector_keys is not None:
|
||||||
|
for word in vector_keys:
|
||||||
|
if word not in nlp.vocab:
|
||||||
|
lexeme = nlp.vocab[word]
|
||||||
|
lexeme.is_oov = False
|
||||||
|
lex_added += 1
|
||||||
|
if vectors_data is not None:
|
||||||
|
nlp.vocab.vectors = Vectors(data=vectors_data, keys=vector_keys)
|
||||||
|
nlp.vocab.vectors.name = '%s_model.vectors' % nlp.meta['lang']
|
||||||
|
nlp.meta['vectors']['name'] = nlp.vocab.vectors.name
|
||||||
|
if prune_vectors >= 1:
|
||||||
|
nlp.vocab.prune_vectors(prune_vectors)
|
||||||
|
|
||||||
def read_vectors(vectors_loc):
|
def read_vectors(vectors_loc):
|
||||||
print("Reading vectors from %s" % vectors_loc)
|
print("Reading vectors from %s" % vectors_loc)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user