Mirror of https://github.com/explosion/spaCy.git (synced 2025-10-30 23:47:31 +03:00)
			
		
		
		
	Support zipped vector files in init-model
This commit is contained in:
		
Parent: 270fcfd925
Commit: db50ac524e
					
				|  | @ -10,6 +10,7 @@ from pathlib import Path | |||
| from preshed.counter import PreshCounter | ||||
| import tarfile | ||||
| import gzip | ||||
| import zipfile | ||||
| 
 | ||||
| from ._messages import Messages | ||||
| from ..vectors import Vectors | ||||
|  | @ -54,14 +55,19 @@ def init_model(lang, output_dir, freqs_loc=None, clusters_loc=None, vectors_loc= | |||
| def open_file(loc): | ||||
|     '''Handle .gz, .tar.gz or unzipped files''' | ||||
|     loc = ensure_path(loc) | ||||
|     print("Open loc") | ||||
|     if tarfile.is_tarfile(str(loc)): | ||||
|         return tarfile.open(str(loc), 'r:gz') | ||||
|     elif loc.parts[-1].endswith('gz'): | ||||
|         return (line.decode('utf8') for line in gzip.open(str(loc), 'r')) | ||||
|     elif loc.parts[-1].endswith('zip'): | ||||
|         zip_file = zipfile.ZipFile(str(loc)) | ||||
|         names = zip_file.namelist() | ||||
|         file_ = zip_file.open(names[0]) | ||||
|         return (line.decode('utf8') for line in file_) | ||||
|     else: | ||||
|         return loc.open('r', encoding='utf8') | ||||
| 
 | ||||
| 
 | ||||
| def create_model(lang, probs, oov_prob, clusters, vectors_data, vector_keys, prune_vectors): | ||||
|     print("Creating model...") | ||||
|     lang_class = get_lang_class(lang) | ||||
|  | @ -104,8 +110,12 @@ def read_vectors(vectors_loc): | |||
|     vectors_data = numpy.zeros(shape=shape, dtype='f') | ||||
|     vectors_keys = [] | ||||
|     for i, line in enumerate(tqdm(f)): | ||||
|         pieces = line.split() | ||||
|         line = line.rstrip() | ||||
|         pieces = line.rsplit(' ', vectors_data.shape[1]+1) | ||||
|         word = pieces.pop(0) | ||||
|         if len(pieces) != vectors_data.shape[1]: | ||||
|             print(word, repr(line)) | ||||
|             raise ValueError("Bad line in file") | ||||
|         vectors_data[i] = numpy.asarray(pieces, dtype='f') | ||||
|         vectors_keys.append(word) | ||||
|     return vectors_data, vectors_keys | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user