mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-31 16:07:41 +03:00 
			
		
		
		
	Add option for base model in init-model CLI (#5467)
Intended for languages like Chinese with a custom tokenizer.
This commit is contained in:
		
							parent
							
								
									9393253b66
								
							
						
					
					
						commit
						49ef06d793
					
				|  | @ -17,7 +17,7 @@ from wasabi import msg | ||||||
| 
 | 
 | ||||||
| from ..vectors import Vectors | from ..vectors import Vectors | ||||||
| from ..errors import Errors, Warnings | from ..errors import Errors, Warnings | ||||||
| from ..util import ensure_path, get_lang_class, OOV_RANK | from ..util import ensure_path, get_lang_class, load_model, OOV_RANK | ||||||
| 
 | 
 | ||||||
| try: | try: | ||||||
|     import ftfy |     import ftfy | ||||||
|  | @ -49,6 +49,7 @@ DEFAULT_OOV_PROB = -20 | ||||||
|         str, |         str, | ||||||
|     ), |     ), | ||||||
|     model_name=("Optional name for the model meta", "option", "mn", str), |     model_name=("Optional name for the model meta", "option", "mn", str), | ||||||
|  |     base_model=("Base model (for languages with custom tokenizers)", "option", "b", str), | ||||||
| ) | ) | ||||||
| def init_model( | def init_model( | ||||||
|     lang, |     lang, | ||||||
|  | @ -61,6 +62,7 @@ def init_model( | ||||||
|     prune_vectors=-1, |     prune_vectors=-1, | ||||||
|     vectors_name=None, |     vectors_name=None, | ||||||
|     model_name=None, |     model_name=None, | ||||||
|  |     base_model=None, | ||||||
| ): | ): | ||||||
|     """ |     """ | ||||||
|     Create a new model from raw data, like word frequencies, Brown clusters |     Create a new model from raw data, like word frequencies, Brown clusters | ||||||
|  | @ -92,7 +94,7 @@ def init_model( | ||||||
|         lex_attrs = read_attrs_from_deprecated(freqs_loc, clusters_loc) |         lex_attrs = read_attrs_from_deprecated(freqs_loc, clusters_loc) | ||||||
| 
 | 
 | ||||||
|     with msg.loading("Creating model..."): |     with msg.loading("Creating model..."): | ||||||
|         nlp = create_model(lang, lex_attrs, name=model_name) |         nlp = create_model(lang, lex_attrs, name=model_name, base_model=base_model) | ||||||
|     msg.good("Successfully created model") |     msg.good("Successfully created model") | ||||||
|     if vectors_loc is not None: |     if vectors_loc is not None: | ||||||
|         add_vectors(nlp, vectors_loc, truncate_vectors, prune_vectors, vectors_name) |         add_vectors(nlp, vectors_loc, truncate_vectors, prune_vectors, vectors_name) | ||||||
|  | @ -152,7 +154,14 @@ def read_attrs_from_deprecated(freqs_loc, clusters_loc): | ||||||
|     return lex_attrs |     return lex_attrs | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def create_model(lang, lex_attrs, name=None): | def create_model(lang, lex_attrs, name=None, base_model=None): | ||||||
|  |     if base_model: | ||||||
|  |         nlp = load_model(base_model) | ||||||
|  |         # keep the tokenizer but remove any existing pipeline components due to | ||||||
|  |         # potentially conflicting vectors | ||||||
|  |         for pipe in nlp.pipe_names: | ||||||
|  |             nlp.remove_pipe(pipe) | ||||||
|  |     else: | ||||||
|         lang_class = get_lang_class(lang) |         lang_class = get_lang_class(lang) | ||||||
|         nlp = lang_class() |         nlp = lang_class() | ||||||
|     for lexeme in nlp.vocab: |     for lexeme in nlp.vocab: | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user