Mirror of https://github.com/explosion/spaCy.git (synced 2025-10-31 16:07:41 +03:00)
			
		
		
		
	Support zipped vector files in init-model
This commit is contained in:
		
parent: 270fcfd925
commit: db50ac524e
					
				|  | @ -10,6 +10,7 @@ from pathlib import Path | ||||||
| from preshed.counter import PreshCounter | from preshed.counter import PreshCounter | ||||||
| import tarfile | import tarfile | ||||||
| import gzip | import gzip | ||||||
|  | import zipfile | ||||||
| 
 | 
 | ||||||
| from ._messages import Messages | from ._messages import Messages | ||||||
| from ..vectors import Vectors | from ..vectors import Vectors | ||||||
|  | @ -54,14 +55,19 @@ def init_model(lang, output_dir, freqs_loc=None, clusters_loc=None, vectors_loc= | ||||||
| def open_file(loc): | def open_file(loc): | ||||||
|     '''Handle .gz, .tar.gz or unzipped files''' |     '''Handle .gz, .tar.gz or unzipped files''' | ||||||
|     loc = ensure_path(loc) |     loc = ensure_path(loc) | ||||||
|  |     print("Open loc") | ||||||
|     if tarfile.is_tarfile(str(loc)): |     if tarfile.is_tarfile(str(loc)): | ||||||
|         return tarfile.open(str(loc), 'r:gz') |         return tarfile.open(str(loc), 'r:gz') | ||||||
|     elif loc.parts[-1].endswith('gz'): |     elif loc.parts[-1].endswith('gz'): | ||||||
|         return (line.decode('utf8') for line in gzip.open(str(loc), 'r')) |         return (line.decode('utf8') for line in gzip.open(str(loc), 'r')) | ||||||
|  |     elif loc.parts[-1].endswith('zip'): | ||||||
|  |         zip_file = zipfile.ZipFile(str(loc)) | ||||||
|  |         names = zip_file.namelist() | ||||||
|  |         file_ = zip_file.open(names[0]) | ||||||
|  |         return (line.decode('utf8') for line in file_) | ||||||
|     else: |     else: | ||||||
|         return loc.open('r', encoding='utf8') |         return loc.open('r', encoding='utf8') | ||||||
| 
 | 
 | ||||||
| 
 |  | ||||||
| def create_model(lang, probs, oov_prob, clusters, vectors_data, vector_keys, prune_vectors): | def create_model(lang, probs, oov_prob, clusters, vectors_data, vector_keys, prune_vectors): | ||||||
|     print("Creating model...") |     print("Creating model...") | ||||||
|     lang_class = get_lang_class(lang) |     lang_class = get_lang_class(lang) | ||||||
|  | @ -104,8 +110,12 @@ def read_vectors(vectors_loc): | ||||||
|     vectors_data = numpy.zeros(shape=shape, dtype='f') |     vectors_data = numpy.zeros(shape=shape, dtype='f') | ||||||
|     vectors_keys = [] |     vectors_keys = [] | ||||||
|     for i, line in enumerate(tqdm(f)): |     for i, line in enumerate(tqdm(f)): | ||||||
|         pieces = line.split() |         line = line.rstrip() | ||||||
|  |         pieces = line.rsplit(' ', vectors_data.shape[1]+1) | ||||||
|         word = pieces.pop(0) |         word = pieces.pop(0) | ||||||
|  |         if len(pieces) != vectors_data.shape[1]: | ||||||
|  |             print(word, repr(line)) | ||||||
|  |             raise ValueError("Bad line in file") | ||||||
|         vectors_data[i] = numpy.asarray(pieces, dtype='f') |         vectors_data[i] = numpy.asarray(pieces, dtype='f') | ||||||
|         vectors_keys.append(word) |         vectors_keys.append(word) | ||||||
|     return vectors_data, vectors_keys |     return vectors_data, vectors_keys | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user