mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-26 21:51:24 +03:00 
			
		
		
		
	Refactor language update (#4316)
* refactor: separate formatting docs and golds in Language.update * fix return typo
This commit is contained in:
		
							parent
							
								
									105a91975b
								
							
						
					
					
						commit
						b408b5b29e
					
				|  | @ -449,29 +449,9 @@ class Language(object): | |||
|     def make_doc(self, text): | ||||
|         return self.tokenizer(text) | ||||
| 
 | ||||
|     def update(self, docs, golds, drop=0.0, sgd=None, losses=None, component_cfg=None): | ||||
|         """Update the models in the pipeline. | ||||
| 
 | ||||
|         docs (iterable): A batch of `Doc` objects. | ||||
|         golds (iterable): A batch of `GoldParse` objects. | ||||
|         drop (float): The droput rate. | ||||
|         sgd (callable): An optimizer. | ||||
|         losses (dict): Dictionary to update with the loss, keyed by component. | ||||
|         component_cfg (dict): Config parameters for specific pipeline | ||||
|             components, keyed by component name. | ||||
| 
 | ||||
|         DOCS: https://spacy.io/api/language#update | ||||
|         """ | ||||
|     def _format_docs_and_golds(self, docs, golds): | ||||
|         """Format golds and docs before update models.""" | ||||
|         expected_keys = ("words", "tags", "heads", "deps", "entities", "cats", "links") | ||||
|         if len(docs) != len(golds): | ||||
|             raise IndexError(Errors.E009.format(n_docs=len(docs), n_golds=len(golds))) | ||||
|         if len(docs) == 0: | ||||
|             return | ||||
|         if sgd is None: | ||||
|             if self._optimizer is None: | ||||
|                 self._optimizer = create_default_optimizer(Model.ops) | ||||
|             sgd = self._optimizer | ||||
|         # Allow dict of args to GoldParse, instead of GoldParse objects. | ||||
|         gold_objs = [] | ||||
|         doc_objs = [] | ||||
|         for doc, gold in zip(docs, golds): | ||||
|  | @ -485,8 +465,32 @@ class Language(object): | |||
|                 gold = GoldParse(doc, **gold) | ||||
|             doc_objs.append(doc) | ||||
|             gold_objs.append(gold) | ||||
|         golds = gold_objs | ||||
|         docs = doc_objs | ||||
| 
 | ||||
|         return doc_objs, gold_objs | ||||
| 
 | ||||
|     def update(self, docs, golds, drop=0.0, sgd=None, losses=None, component_cfg=None): | ||||
|         """Update the models in the pipeline. | ||||
| 
 | ||||
|         docs (iterable): A batch of `Doc` objects. | ||||
|         golds (iterable): A batch of `GoldParse` objects. | ||||
|         drop (float): The droput rate. | ||||
|         sgd (callable): An optimizer. | ||||
|         losses (dict): Dictionary to update with the loss, keyed by component. | ||||
|         component_cfg (dict): Config parameters for specific pipeline | ||||
|             components, keyed by component name. | ||||
| 
 | ||||
|         DOCS: https://spacy.io/api/language#update | ||||
|         """ | ||||
|         if len(docs) != len(golds): | ||||
|             raise IndexError(Errors.E009.format(n_docs=len(docs), n_golds=len(golds))) | ||||
|         if len(docs) == 0: | ||||
|             return | ||||
|         if sgd is None: | ||||
|             if self._optimizer is None: | ||||
|                 self._optimizer = create_default_optimizer(Model.ops) | ||||
|             sgd = self._optimizer | ||||
|         # Allow dict of args to GoldParse, instead of GoldParse objects. | ||||
|         docs, golds = self._format_docs_and_golds(docs, golds) | ||||
|         grads = {} | ||||
| 
 | ||||
|         def get_grads(W, dW, key=None): | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user