mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-31 07:57:35 +03:00 
			
		
		
		
	Make pipe base class a bit less presumptuous
This commit is contained in:
		
							parent
							
								
									16b5144095
								
							
						
					
					
						commit
						06eb428ed1
					
				|  | @ -83,8 +83,12 @@ class Pipe(object): | ||||||
|         """ |         """ | ||||||
|         for docs in util.minibatch(stream, size=batch_size): |         for docs in util.minibatch(stream, size=batch_size): | ||||||
|             docs = list(docs) |             docs = list(docs) | ||||||
|             scores, tensors = self.predict(docs) |             predictions = self.predict(docs) | ||||||
|  |             if isinstance(predictions, tuple) and len(tuple) == 2: | ||||||
|  |                 scores, tensors = predictions | ||||||
|                 self.set_annotations(docs, scores, tensor=tensors) |                 self.set_annotations(docs, scores, tensor=tensors) | ||||||
|  |             else: | ||||||
|  |                 self.set_annotations(docs, predictions) | ||||||
|             yield from docs |             yield from docs | ||||||
| 
 | 
 | ||||||
|     def predict(self, docs): |     def predict(self, docs): | ||||||
|  | @ -133,6 +137,7 @@ class Pipe(object): | ||||||
|         If no model has been initialized yet, the model is added.""" |         If no model has been initialized yet, the model is added.""" | ||||||
|         if self.model is True: |         if self.model is True: | ||||||
|             self.model = self.Model(**self.cfg) |             self.model = self.Model(**self.cfg) | ||||||
|  |         if hasattr(self, "vocab"): | ||||||
|             link_vectors_to_models(self.vocab) |             link_vectors_to_models(self.vocab) | ||||||
|         if sgd is None: |         if sgd is None: | ||||||
|             sgd = self.create_optimizer() |             sgd = self.create_optimizer() | ||||||
|  | @ -153,6 +158,7 @@ class Pipe(object): | ||||||
|         serialize["cfg"] = lambda: srsly.json_dumps(self.cfg) |         serialize["cfg"] = lambda: srsly.json_dumps(self.cfg) | ||||||
|         if self.model not in (True, False, None): |         if self.model not in (True, False, None): | ||||||
|             serialize["model"] = self.model.to_bytes |             serialize["model"] = self.model.to_bytes | ||||||
|  |         if hasattr(self, "vocab"): | ||||||
|             serialize["vocab"] = self.vocab.to_bytes |             serialize["vocab"] = self.vocab.to_bytes | ||||||
|         exclude = util.get_serialization_exclude(serialize, exclude, kwargs) |         exclude = util.get_serialization_exclude(serialize, exclude, kwargs) | ||||||
|         return util.to_bytes(serialize, exclude) |         return util.to_bytes(serialize, exclude) | ||||||
|  | @ -173,6 +179,7 @@ class Pipe(object): | ||||||
| 
 | 
 | ||||||
|         deserialize = OrderedDict() |         deserialize = OrderedDict() | ||||||
|         deserialize["cfg"] = lambda b: self.cfg.update(srsly.json_loads(b)) |         deserialize["cfg"] = lambda b: self.cfg.update(srsly.json_loads(b)) | ||||||
|  |         if hasattr(self, "vocab"): | ||||||
|             deserialize["vocab"] = lambda b: self.vocab.from_bytes(b) |             deserialize["vocab"] = lambda b: self.vocab.from_bytes(b) | ||||||
|         deserialize["model"] = load_model |         deserialize["model"] = load_model | ||||||
|         exclude = util.get_serialization_exclude(deserialize, exclude, kwargs) |         exclude = util.get_serialization_exclude(deserialize, exclude, kwargs) | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user