mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-30 23:47:31 +03:00 
			
		
		
		
	It works!
The serialization-related code from biaffine was missing; this commit adds it.
This commit is contained in:
		
							parent
							
								
									ba1bf8ae72
								
							
						
					
					
						commit
						bd17c38b74
					
				|  | @ -44,8 +44,9 @@ def build_wl_coref_model( | |||
|             }, | ||||
|         ) | ||||
| 
 | ||||
|         coref_model = tok2vec >> coref_clusterer | ||||
|     return coref_model | ||||
|         model = tok2vec >> coref_clusterer | ||||
|         model.set_ref("coref_clusterer", coref_clusterer) | ||||
|     return model | ||||
| 
 | ||||
| 
 | ||||
| def coref_init(model: Model, X=None, Y=None): | ||||
|  |  | |||
|  | @ -6,6 +6,7 @@ from thinc.api import Model, Config, Optimizer | |||
| from thinc.api import set_dropout_rate, to_categorical | ||||
| from itertools import islice | ||||
| from statistics import mean | ||||
| import srsly | ||||
| 
 | ||||
| from .trainable_pipe import TrainablePipe | ||||
| from ..language import Language | ||||
|  | @ -13,7 +14,7 @@ from ..training import Example, validate_examples, validate_get_examples | |||
| from ..errors import Errors | ||||
| from ..tokens import Doc | ||||
| from ..vocab import Vocab | ||||
| from ..util import registry | ||||
| from ..util import registry, from_disk, from_bytes | ||||
| 
 | ||||
| from ..ml.models.coref_util import ( | ||||
|     create_gold_scores, | ||||
|  | @ -316,3 +317,57 @@ class CoreferenceResolver(TrainablePipe): | |||
| 
 | ||||
|         assert len(X) > 0, Errors.E923.format(name=self.name) | ||||
|         self.model.initialize(X=X, Y=Y) | ||||
| 
 | ||||
|         # Store the input dimensionality. nI and nO are not stored explicitly | ||||
|         # for PyTorch models. This makes it tricky to reconstruct the model | ||||
|         # during deserialization. So, besides storing the labels, we also | ||||
|         # store the number of inputs. | ||||
|         coref_clusterer = self.model.get_ref("coref_clusterer") | ||||
|         self.cfg["nI"] = coref_clusterer.get_dim("nI") | ||||
| 
 | ||||
def from_bytes(self, bytes_data, *, exclude=tuple()):
    """Restore this component from a serialized bytestring.

    The payload is walked twice. The first pass restores "cfg" and
    "vocab"; the model is then initialized via
    ``_initialize_from_disk()``; only after that does the second pass
    load "model". The weights have to come last because the model must
    be initialized before it can be deserialized.

    bytes_data: The serialized payload, handed to the module-level
        ``from_bytes`` helper imported from ``..util``.
    exclude: Forwarded to the helper and to ``self.vocab.from_bytes``.
        Presumably the names of serialization fields to skip -- confirm
        against the ``..util`` helper.
    RETURNS: ``self``.
    RAISES: ValueError carrying ``Errors.E149`` if deserializing the
        model raises AttributeError. This mirrors what ``from_disk``
        does for the same failure.
    """

    def load_model(b):
        # Same guard as from_disk: report a model object that cannot be
        # deserialized as E149 instead of leaking a bare AttributeError.
        try:
            self.model.from_bytes(b)
        except AttributeError:
            raise ValueError(Errors.E149) from None

    deserializers = {
        "cfg": lambda b: self.cfg.update(srsly.json_loads(b)),
        "vocab": lambda b: self.vocab.from_bytes(b, exclude=exclude),
    }
    from_bytes(bytes_data, deserializers, exclude)

    # cfg now holds "nI", which _initialize_from_disk reads when it has
    # to set the clusterer's input dimension.
    self._initialize_from_disk()

    model_deserializers = {
        "model": load_model,
    }
    from_bytes(bytes_data, model_deserializers, exclude)

    return self
| 
 | ||||
def from_disk(self, path, exclude=tuple()):
    """Restore this component from a directory on disk.

    Reading is done in two passes over ``path``. The first pass
    restores "cfg" and "vocab"; the model is then initialized via
    ``_initialize_from_disk()``; the second pass reads "model".

    path: Location to read from, handed to the module-level
        ``from_disk`` helper imported from ``..util``.
    exclude: Forwarded to the helper and to ``self.vocab.from_disk``.
        Presumably the names of serialization fields to skip -- confirm
        against the ``..util`` helper.
    RETURNS: ``self``.
    RAISES: ValueError carrying ``Errors.E149`` if reading or
        deserializing the model raises AttributeError.
    """

    def _read_weights(weights_path):
        # The guard spans both the open() and the from_bytes() call, so
        # an AttributeError from either one is reported as E149.
        try:
            with open(weights_path, "rb") as handle:
                self.model.from_bytes(handle.read())
        except AttributeError:
            raise ValueError(Errors.E149) from None

    # Pass 1: settings and vocabulary. cfg must be in place before the
    # model is initialized, because initialization reads cfg["nI"].
    from_disk(
        path,
        {
            "cfg": lambda p: self.cfg.update(srsly.read_json(p)),
            "vocab": lambda p: self.vocab.from_disk(p, exclude=exclude),
        },
        exclude,
    )

    self._initialize_from_disk()

    # Pass 2: the weights, now that the model has been initialized.
    from_disk(path, {"model": _read_weights}, exclude)

    return self
| 
 | ||||
def _initialize_from_disk(self):
    """Initialize the model so that weights can be loaded into it.

    Looks up the "coref_clusterer" reference on ``self.model``. If that
    node's "nI" dimension check comes back as ``None``, the dimension is
    set from ``self.cfg["nI"]``. ``self.model.initialize()`` is then
    called with no arguments. Both ``from_bytes`` and ``from_disk`` call
    this helper, despite its name.
    """
    # The PyTorch model is built lazily, so it has to be initialized
    # explicitly before any weights are deserialized into it.
    clusterer = self.model.get_ref("coref_clusterer")
    input_dim_state = clusterer.has_dim("nI")
    # NOTE(review): this relies on has_dim() returning None (not False)
    # for a dimension that is declared but unset -- confirm against the
    # pinned thinc version.
    if input_dim_state is None:
        clusterer.set_dim("nI", self.cfg["nI"])
    self.model.initialize()
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user