mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-30 23:47:31 +03:00 
			
		
		
		
	Set pretrained_vectors in begin_training
This commit is contained in:
		
							parent
							
								
									95a9615221
								
							
						
					
					
						commit
						9bf6e93b3e
					
				|  | @ -516,6 +516,7 @@ class Tagger(Pipe): | |||
|             vocab.morphology = Morphology(vocab.strings, new_tag_map, | ||||
|                                           vocab.morphology.lemmatizer, | ||||
|                                           exc=vocab.morphology.exc) | ||||
|         self.cfg['pretrained_vectors'] = kwargs.get('pretrained_vectors') | ||||
|         if self.model is True: | ||||
|             self.model = self.Model(self.vocab.morphology.n_tags, **self.cfg) | ||||
|         link_vectors_to_models(self.vocab) | ||||
|  | @ -910,12 +911,15 @@ class TextCategorizer(Pipe): | |||
|         self.labels.append(label) | ||||
|         return 1 | ||||
| 
 | ||||
|     def begin_training(self, gold_tuples=tuple(), pipeline=None, sgd=None): | ||||
|     def begin_training(self, gold_tuples=tuple(), pipeline=None, sgd=None, | ||||
|                        **kwargs): | ||||
|         if pipeline and getattr(pipeline[0], 'name', None) == 'tensorizer': | ||||
|             token_vector_width = pipeline[0].model.nO | ||||
|         else: | ||||
|             token_vector_width = 64 | ||||
| 
 | ||||
|         if self.model is True: | ||||
|             self.cfg['pretrained_vectors'] = kwargs.get('pretrained_vectors') | ||||
|             self.model = self.Model(len(self.labels), token_vector_width, | ||||
|                                     **self.cfg) | ||||
|             link_vectors_to_models(self.vocab) | ||||
|  |  | |||
|  | @ -896,7 +896,6 @@ cdef class Parser: | |||
|             # TODO: Remove this once we don't have to handle previous models | ||||
|             if 'pretrained_dims' in self.cfg and 'pretrained_vectors' not in self.cfg: | ||||
|                 self.cfg['pretrained_vectors'] = self.vocab.vectors.name | ||||
|             print("Create parser model", self.cfg) | ||||
|             path = util.ensure_path(path) | ||||
|             if self.model is True: | ||||
|                 self.model, cfg = self.Model(**self.cfg) | ||||
|  | @ -944,7 +943,6 @@ cdef class Parser: | |||
|             # TODO: Remove this once we don't have to handle previous models | ||||
|             if 'pretrained_dims' in self.cfg and 'pretrained_vectors' not in self.cfg: | ||||
|                 self.cfg['pretrained_vectors'] = self.vocab.vectors.name | ||||
|             print("Create parser model", self.cfg) | ||||
|             if self.model is True: | ||||
|                 self.model, cfg = self.Model(**self.cfg) | ||||
|             else: | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user