mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-31 16:07:41 +03:00 
			
		
		
		
	is_trainable method
This commit is contained in:
		
							parent
							
								
									dc06912c76
								
							
						
					
					
						commit
						4e3ace4b8c
					
				|  | @ -1091,7 +1091,8 @@ class Language: | |||
|             for name, proc in self.pipeline: | ||||
|                 if ( | ||||
|                     name not in exclude | ||||
|                     and hasattr(proc, "model") | ||||
|                     and hasattr(proc, "is_trainable") | ||||
|                     and proc.is_trainable() | ||||
|                     and proc.model not in (True, False, None) | ||||
|                 ): | ||||
|                     proc.finish_update(sgd) | ||||
|  | @ -1297,7 +1298,9 @@ class Language: | |||
|         for name, pipe in self.pipeline: | ||||
|             kwargs = component_cfg.get(name, {}) | ||||
|             kwargs.setdefault("batch_size", batch_size) | ||||
|             if not hasattr(pipe, "pipe"): | ||||
|             # non-trainable components may have a pipe() implementation that refers to dummy | ||||
|             # predict and set_annotations methods | ||||
|             if not hasattr(pipe, "pipe") or not hasattr(pipe, "is_trainable") or not pipe.is_trainable(): | ||||
|                 docs = _pipe(docs, pipe, kwargs) | ||||
|             else: | ||||
|                 docs = pipe.pipe(docs, **kwargs) | ||||
|  |  | |||
|  | @ -2,8 +2,9 @@ from typing import Optional, Union, List, Dict, Tuple, Iterable, Any, Callable | |||
| from collections import defaultdict | ||||
| from pathlib import Path | ||||
| import srsly | ||||
| from spacy.training import Example | ||||
| 
 | ||||
| from .pipe import Pipe | ||||
| from ..training import Example | ||||
| from ..language import Language | ||||
| from ..errors import Errors | ||||
| from ..util import ensure_path, to_disk, from_disk, SimpleFrozenList | ||||
|  | @ -51,7 +52,7 @@ def make_entity_ruler( | |||
|     ) | ||||
| 
 | ||||
| 
 | ||||
| class EntityRuler: | ||||
| class EntityRuler(Pipe): | ||||
|     """The EntityRuler lets you add spans to the `Doc.ents` using token-based | ||||
|     rules or exact phrase matches. It can be combined with the statistical | ||||
|     `EntityRecognizer` to boost accuracy, or used on its own to implement a | ||||
|  | @ -134,7 +135,6 @@ class EntityRuler: | |||
| 
 | ||||
|         DOCS: https://nightly.spacy.io/api/entityruler#call | ||||
|         """ | ||||
|         self._require_patterns() | ||||
|         matches = list(self.matcher(doc)) + list(self.phrase_matcher(doc)) | ||||
|         matches = set( | ||||
|             [(m_id, start, end) for m_id, start, end in matches if start != end] | ||||
|  | @ -315,11 +315,6 @@ class EntityRuler: | |||
|         self.phrase_patterns = defaultdict(list) | ||||
|         self._ent_ids = defaultdict(dict) | ||||
| 
 | ||||
|     def _require_patterns(self) -> None: | ||||
|         """Raise an error if the component has no patterns.""" | ||||
|         if not self.patterns or list(self.patterns) == [""]: | ||||
|             raise ValueError(Errors.E900.format(name=self.name)) | ||||
| 
 | ||||
|     def _split_label(self, label: str) -> Tuple[str, str]: | ||||
|         """Split Entity label into ent_label and ent_id if it contains self.ent_id_sep | ||||
| 
 | ||||
|  | @ -348,6 +343,12 @@ class EntityRuler: | |||
|         validate_examples(examples, "EntityRuler.score") | ||||
|         return Scorer.score_spans(examples, "ents", **kwargs) | ||||
| 
 | ||||
|     def predict(self, docs): | ||||
|         pass | ||||
| 
 | ||||
|     def set_annotations(self, docs, scores): | ||||
|         pass | ||||
| 
 | ||||
|     def from_bytes( | ||||
|         self, patterns_bytes: bytes, *, exclude: Iterable[str] = SimpleFrozenList() | ||||
|     ) -> "EntityRuler": | ||||
|  |  | |||
|  | @ -228,6 +228,9 @@ cdef class Pipe: | |||
|     def is_resizable(self): | ||||
|         return hasattr(self, "model") and "resize_output" in self.model.attrs | ||||
| 
 | ||||
|     def is_trainable(self): | ||||
|         return hasattr(self, "model") and isinstance(self.model, Model) | ||||
| 
 | ||||
|     def set_output(self, nO): | ||||
|         if self.is_resizable(): | ||||
|             self.model.attrs["resize_output"](self.model, nO) | ||||
|  |  | |||
|  | @ -17,8 +17,12 @@ def console_logger(progress_bar: bool = False): | |||
|         nlp: "Language", stdout: IO = sys.stdout, stderr: IO = sys.stderr | ||||
|     ) -> Tuple[Callable[[Optional[Dict[str, Any]]], None], Callable[[], None]]: | ||||
|         msg = Printer(no_print=True) | ||||
|         # we assume here that only components are enabled that should be trained & logged | ||||
|         logged_pipes = nlp.pipe_names | ||||
|         # ensure that only trainable components are logged | ||||
|         logged_pipes = [ | ||||
|             name | ||||
|             for name, proc in nlp.pipeline | ||||
|             if hasattr(proc, "is_trainable") and proc.is_trainable() | ||||
|         ] | ||||
|         eval_frequency = nlp.config["training"]["eval_frequency"] | ||||
|         score_weights = nlp.config["training"]["score_weights"] | ||||
|         score_cols = [col for col, value in score_weights.items() if value is not None] | ||||
|  | @ -43,7 +47,7 @@ def console_logger(progress_bar: bool = False): | |||
|                 return | ||||
|             losses = [ | ||||
|                 "{0:.2f}".format(float(info["losses"][pipe_name])) | ||||
|                 for pipe_name in logged_pipes if pipe_name in info["losses"] | ||||
|                 for pipe_name in logged_pipes | ||||
|             ] | ||||
| 
 | ||||
|             scores = [] | ||||
|  |  | |||
|  | @ -181,7 +181,8 @@ def train_while_improving( | |||
|         for name, proc in nlp.pipeline: | ||||
|             if ( | ||||
|                 name not in exclude | ||||
|                 and hasattr(proc, "model") | ||||
|                 and hasattr(proc, "is_trainable") | ||||
|                 and proc.is_trainable() | ||||
|                 and proc.model not in (True, False, None) | ||||
|             ): | ||||
|                 proc.finish_update(optimizer) | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user