mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-26 05:31:15 +03:00 
			
		
		
		
	Fix init-model for npz vectors
This commit is contained in:
		
							parent
							
								
									59d655e8d0
								
							
						
					
					
						commit
						dee8bdb900
					
				|  | @ -66,14 +66,14 @@ def init_model(lang, output_dir, freqs_loc=None, clusters_loc=None, jsonl_loc=No | ||||||
|         if freqs_loc is not None and not freqs_loc.exists(): |         if freqs_loc is not None and not freqs_loc.exists(): | ||||||
|             prints(freqs_loc, title=Messages.M037, exits=1) |             prints(freqs_loc, title=Messages.M037, exits=1) | ||||||
|         lex_attrs = read_attrs_from_deprecated(freqs_loc, clusters_loc) |         lex_attrs = read_attrs_from_deprecated(freqs_loc, clusters_loc) | ||||||
|     vectors_loc = ensure_path(vectors_loc) | 
 | ||||||
|     if vectors_loc and vectors_loc.parts[-1].endswith('.npz'): |     nlp = create_model(lang, lex_attrs) | ||||||
|         vectors_data = numpy.load(vectors_loc.open('rb')) |     if vectors_loc is not None: | ||||||
|         vector_keys = [lex['orth'] for lex in lex_attrs |         add_vectors(nlp, vectors_loc, prune_vectors) | ||||||
|                        if 'id' in lex and lex['id'] < vectors_data.shape[0]] |     vec_added = len(nlp.vocab.vectors) | ||||||
|     else: |     lex_added = len(nlp.vocab) | ||||||
|         vectors_data, vector_keys = read_vectors(vectors_loc) if vectors_loc else (None, None) |     prints(Messages.M039.format(entries=lex_added, vectors=vec_added), | ||||||
|     nlp = create_model(lang, lex_attrs, vectors_data, vector_keys, prune_vectors) |            title=Messages.M038) | ||||||
|     if not output_dir.exists(): |     if not output_dir.exists(): | ||||||
|         output_dir.mkdir() |         output_dir.mkdir() | ||||||
|     nlp.to_disk(output_dir) |     nlp.to_disk(output_dir) | ||||||
|  | @ -112,7 +112,7 @@ def read_attrs_from_deprecated(freqs_loc, clusters_loc): | ||||||
|     return lex_attrs |     return lex_attrs | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def create_model(lang, lex_attrs, vectors_data, vector_keys, prune_vectors): | def create_model(lang, lex_attrs): | ||||||
|     print("Creating model...") |     print("Creating model...") | ||||||
|     lang_class = get_lang_class(lang) |     lang_class = get_lang_class(lang) | ||||||
|     nlp = lang_class() |     nlp = lang_class() | ||||||
|  | @ -120,13 +120,26 @@ def create_model(lang, lex_attrs, vectors_data, vector_keys, prune_vectors): | ||||||
|         lexeme.rank = 0 |         lexeme.rank = 0 | ||||||
|     lex_added = 0 |     lex_added = 0 | ||||||
|     for attrs in lex_attrs: |     for attrs in lex_attrs: | ||||||
|  |         if 'settings' in attrs: | ||||||
|  |             continue | ||||||
|         lexeme = nlp.vocab[attrs['orth']] |         lexeme = nlp.vocab[attrs['orth']] | ||||||
|         lexeme.set_attrs(**intify_attrs(attrs)) |         lexeme.set_attrs(**attrs) | ||||||
|         lexeme.is_oov = False |         lexeme.is_oov = False | ||||||
|         lex_added += 1 |         lex_added += 1 | ||||||
|         lex_added += 1 |         lex_added += 1 | ||||||
|     oov_prob = min(lex.prob for lex in nlp.vocab) |     oov_prob = min(lex.prob for lex in nlp.vocab) | ||||||
|     nlp.vocab.cfg.update({'oov_prob': oov_prob-1}) |     nlp.vocab.cfg.update({'oov_prob': oov_prob-1}) | ||||||
|  |     return nlp | ||||||
|  | 
 | ||||||
|  | def add_vectors(nlp, vectors_loc, prune_vectors): | ||||||
|  |     vectors_loc = ensure_path(vectors_loc) | ||||||
|  |     if vectors_loc and vectors_loc.parts[-1].endswith('.npz'): | ||||||
|  |         nlp.vocab.vectors = Vectors(data=numpy.load(vectors_loc.open('rb'))) | ||||||
|  |         for lex in nlp.vocab: | ||||||
|  |             if lex.rank: | ||||||
|  |                 nlp.vocab.vectors.add(lex.orth, row=lex.rank) | ||||||
|  |     else: | ||||||
|  |         vectors_data, vector_keys = read_vectors(vectors_loc) if vectors_loc else (None, None) | ||||||
|         if vector_keys is not None: |         if vector_keys is not None: | ||||||
|             for word in vector_keys: |             for word in vector_keys: | ||||||
|                 if word not in nlp.vocab: |                 if word not in nlp.vocab: | ||||||
|  | @ -135,13 +148,10 @@ def create_model(lang, lex_attrs, vectors_data, vector_keys, prune_vectors): | ||||||
|                     lex_added += 1 |                     lex_added += 1 | ||||||
|         if vectors_data is not None: |         if vectors_data is not None: | ||||||
|             nlp.vocab.vectors = Vectors(data=vectors_data, keys=vector_keys) |             nlp.vocab.vectors = Vectors(data=vectors_data, keys=vector_keys) | ||||||
|  |     nlp.vocab.vectors.name = '%s_model.vectors' % nlp.meta['lang'] | ||||||
|  |     nlp.meta['vectors']['name'] = nlp.vocab.vectors.name | ||||||
|     if prune_vectors >= 1: |     if prune_vectors >= 1: | ||||||
|         nlp.vocab.prune_vectors(prune_vectors) |         nlp.vocab.prune_vectors(prune_vectors) | ||||||
|     vec_added = len(nlp.vocab.vectors) |  | ||||||
|     prints(Messages.M039.format(entries=lex_added, vectors=vec_added), |  | ||||||
|            title=Messages.M038) |  | ||||||
|     return nlp |  | ||||||
| 
 |  | ||||||
| 
 | 
 | ||||||
| def read_vectors(vectors_loc): | def read_vectors(vectors_loc): | ||||||
|     print("Reading vectors from %s" % vectors_loc) |     print("Reading vectors from %s" % vectors_loc) | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user