mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-31 07:57:35 +03:00 
			
		
		
		
	Make pipe base class a bit less presumptuous
This commit is contained in:
		
							parent
							
								
									16b5144095
								
							
						
					
					
						commit
						06eb428ed1
					
				|  | @ -83,8 +83,12 @@ class Pipe(object): | |||
|         """ | ||||
|         for docs in util.minibatch(stream, size=batch_size): | ||||
|             docs = list(docs) | ||||
|             scores, tensors = self.predict(docs) | ||||
|             self.set_annotations(docs, scores, tensor=tensors) | ||||
|             predictions = self.predict(docs) | ||||
|             if isinstance(predictions, tuple) and len(tuple) == 2: | ||||
|                 scores, tensors = predictions | ||||
|                 self.set_annotations(docs, scores, tensor=tensors) | ||||
|             else: | ||||
|                 self.set_annotations(docs, predictions) | ||||
|             yield from docs | ||||
| 
 | ||||
|     def predict(self, docs): | ||||
|  | @ -133,7 +137,8 @@ class Pipe(object): | |||
|         If no model has been initialized yet, the model is added.""" | ||||
|         if self.model is True: | ||||
|             self.model = self.Model(**self.cfg) | ||||
|         link_vectors_to_models(self.vocab) | ||||
|         if hasattr(self, "vocab"): | ||||
|             link_vectors_to_models(self.vocab) | ||||
|         if sgd is None: | ||||
|             sgd = self.create_optimizer() | ||||
|         return sgd | ||||
|  | @ -153,7 +158,8 @@ class Pipe(object): | |||
|         serialize["cfg"] = lambda: srsly.json_dumps(self.cfg) | ||||
|         if self.model not in (True, False, None): | ||||
|             serialize["model"] = self.model.to_bytes | ||||
|         serialize["vocab"] = self.vocab.to_bytes | ||||
|         if hasattr(self, "vocab"): | ||||
|             serialize["vocab"] = self.vocab.to_bytes | ||||
|         exclude = util.get_serialization_exclude(serialize, exclude, kwargs) | ||||
|         return util.to_bytes(serialize, exclude) | ||||
| 
 | ||||
|  | @ -173,7 +179,8 @@ class Pipe(object): | |||
| 
 | ||||
|         deserialize = OrderedDict() | ||||
|         deserialize["cfg"] = lambda b: self.cfg.update(srsly.json_loads(b)) | ||||
|         deserialize["vocab"] = lambda b: self.vocab.from_bytes(b) | ||||
|         if hasattr(self, "vocab"): | ||||
|             deserialize["vocab"] = lambda b: self.vocab.from_bytes(b) | ||||
|         deserialize["model"] = load_model | ||||
|         exclude = util.get_serialization_exclude(deserialize, exclude, kwargs) | ||||
|         util.from_bytes(bytes_data, deserialize, exclude) | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user