mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-31 07:57:35 +03:00 
			
		
		
		
	Move replacement logic to Language.from_config
This commit is contained in:
		
							parent
							
								
									0f3e3eedc2
								
							
						
					
					
						commit
						325f47500d
					
				|  | @ -1629,6 +1629,7 @@ class Language: | |||
|         # Later we replace the component config with the raw config again. | ||||
|         interpolated = filled.interpolate() if not filled.is_interpolated else filled | ||||
|         pipeline = interpolated.get("components", {}) | ||||
|         sourced = util.get_sourced_components(interpolated) | ||||
|         # If components are loaded from a source (existing models), we cache | ||||
|         # them here so they're only loaded once | ||||
|         source_nlps = {} | ||||
|  | @ -1671,6 +1672,17 @@ class Language: | |||
|                 raise ValueError( | ||||
|                     Errors.E942.format(name="pipeline_creation", value=type(nlp)) | ||||
|                 ) | ||||
|         # Replace tok2vec listeners of sourced components where the config asks for it | ||||
|         for name, proc in nlp.pipeline: | ||||
|             if getattr(proc, "listening_components", None):  # e.g. tok2vec/transformer | ||||
|                 for listener in proc.listening_components: | ||||
|                     # If it's a component sourced from another pipeline, we check if | ||||
|                     # the tok2vec listeners should be replaced with standalone tok2vec | ||||
|                     # models (e.g. so component can be frozen without its performance | ||||
|                     # degrading when other components/tok2vec are updated) | ||||
|                     paths = sourced.get(listener, {}).get("replace_listeners", []) | ||||
|                     if paths: | ||||
|                         nlp.replace_listeners(name, listener, paths) | ||||
|         return nlp | ||||
| 
 | ||||
|     def replace_listeners( | ||||
|  | @ -1744,6 +1756,7 @@ class Language: | |||
|             # Go over the listener layers and replace them | ||||
|             for listener in pipe_listeners: | ||||
|                 util.replace_model_node(pipe.model, listener, tok2vec.model.copy()) | ||||
|                 tok2vec.remove_listener(listener, pipe_name) | ||||
| 
 | ||||
|     def to_disk( | ||||
|         self, path: Union[str, Path], *, exclude: Iterable[str] = SimpleFrozenList() | ||||
|  |  | |||
|  | @ -14,7 +14,8 @@ from ..vectors import Vectors | |||
| from ..errors import Errors, Warnings | ||||
| from ..schemas import ConfigSchemaTraining | ||||
| from ..util import registry, load_model_from_config, resolve_dot_names, logger | ||||
| from ..util import load_model, ensure_path, OOV_RANK, DEFAULT_OOV_PROB | ||||
| from ..util import load_model, ensure_path, get_sourced_components | ||||
| from ..util import OOV_RANK, DEFAULT_OOV_PROB | ||||
| 
 | ||||
| if TYPE_CHECKING: | ||||
|     from ..language import Language  # noqa: F401 | ||||
|  | @ -70,14 +71,7 @@ def init_nlp(config: Config, *, use_gpu: int = -1) -> "Language": | |||
|     for name, proc in nlp.pipeline: | ||||
|         if getattr(proc, "listening_components", None):  # e.g. tok2vec/transformer | ||||
|             for listener in proc.listening_components: | ||||
|                 # If it's a component sourced from another pipeline, we check if | ||||
|                 # the tok2vec listeners should be replaced with standalone tok2vec | ||||
|                 # models (e.g. so component can be frozen without its performance | ||||
|                 # degrading when other components/tok2vec are updated) | ||||
|                 paths = sourced.get(listener, {}).get("replace_listeners", []) | ||||
|                 if paths: | ||||
|                     nlp.replace_listeners(name, listener, paths) | ||||
|                 elif listener in frozen_components and name not in frozen_components: | ||||
|                 if listener in frozen_components and name not in frozen_components: | ||||
|                     logger.warning(Warnings.W087.format(name=name, listener=listener)) | ||||
|                 # We always check this regardless, in case user freezes tok2vec | ||||
|                 if listener not in frozen_components and name in frozen_components: | ||||
|  | @ -181,20 +175,6 @@ def init_tok2vec( | |||
|     return False | ||||
| 
 | ||||
| 
 | ||||
| def get_sourced_components( | ||||
|     config: Union[Dict[str, Any], Config] | ||||
| ) -> Dict[str, Dict[str, Any]]: | ||||
|     """RETURNS (Dict[str, Dict[str, Any]]): All sourced components in the | ||||
|     original config, keyed by component name and mapped to their config | ||||
|     block, e.g. {"source": "en_core_web_sm"}. If the block contains a key | ||||
|     "factory", we assume it refers to a component factory and skip it. | ||||
|     """ | ||||
|     return { | ||||
|         name: cfg | ||||
|         for name, cfg in config.get("components", {}).items() | ||||
|         if "factory" not in cfg and "source" in cfg | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| def convert_vectors( | ||||
|     nlp: "Language", | ||||
|     vectors_loc: Optional[Path], | ||||
|  |  | |||
|  | @ -434,6 +434,20 @@ def load_model_from_config( | |||
|     return nlp | ||||
| 
 | ||||
| 
 | ||||
| def get_sourced_components( | ||||
|     config: Union[Dict[str, Any], Config] | ||||
| ) -> Dict[str, Dict[str, Any]]: | ||||
|     """RETURNS (Dict[str, Dict[str, Any]]): All sourced components in the | ||||
|     original config, keyed by component name and mapped to their config | ||||
|     block, e.g. {"source": "en_core_web_sm"}. If the block contains a key | ||||
|     "factory", we assume it refers to a component factory and skip it. | ||||
|     """ | ||||
|     return { | ||||
|         name: cfg | ||||
|         for name, cfg in config.get("components", {}).items() | ||||
|         if "factory" not in cfg and "source" in cfg | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| def resolve_dot_names(config: Config, dot_names: List[Optional[str]]) -> Tuple[Any]: | ||||
|     """Resolve one or more "dot notation" names, e.g. corpora.train. | ||||
|     The paths could point anywhere into the config, so we don't know which | ||||
|  | @ -1480,5 +1494,6 @@ def _pipe(docs, proc, name, default_error_handler, kwargs): | |||
| def raise_error(proc_name, proc, docs, e): | ||||
|     raise e | ||||
| 
 | ||||
| 
 | ||||
| def ignore_error(proc_name, proc, docs, e): | ||||
|     pass | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user