mirror of
https://github.com/explosion/spaCy.git
synced 2025-08-04 12:20:20 +03:00
Sync names and listeners for most recently modified pipeline
This commit is contained in:
parent
9c1e33a9af
commit
180d15b39c
|
@ -716,6 +716,11 @@ class Language:
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
pipe = source.get_pipe(source_name)
|
pipe = source.get_pipe(source_name)
|
||||||
|
# There is no actual solution here. Either the component has the right
|
||||||
|
# name for the source pipeline or the component has the right name for
|
||||||
|
# the current pipeline. This prioritizes the current pipeline.
|
||||||
|
if hasattr(pipe, "name"):
|
||||||
|
pipe.name = name
|
||||||
# Make sure the source config is interpolated so we don't end up with
|
# Make sure the source config is interpolated so we don't end up with
|
||||||
# orphaned variables in our final config
|
# orphaned variables in our final config
|
||||||
source_config = source.config.interpolate()
|
source_config = source.config.interpolate()
|
||||||
|
@ -793,6 +798,7 @@ class Language:
|
||||||
pipe_index = self._get_pipe_index(before, after, first, last)
|
pipe_index = self._get_pipe_index(before, after, first, last)
|
||||||
self._pipe_meta[name] = self.get_factory_meta(factory_name)
|
self._pipe_meta[name] = self.get_factory_meta(factory_name)
|
||||||
self._components.insert(pipe_index, (name, pipe_component))
|
self._components.insert(pipe_index, (name, pipe_component))
|
||||||
|
self._link_components()
|
||||||
return pipe_component
|
return pipe_component
|
||||||
|
|
||||||
def _get_pipe_index(
|
def _get_pipe_index(
|
||||||
|
@ -928,6 +934,10 @@ class Language:
|
||||||
if old_name in self._config["initialize"]["components"]:
|
if old_name in self._config["initialize"]["components"]:
|
||||||
init_cfg = self._config["initialize"]["components"].pop(old_name)
|
init_cfg = self._config["initialize"]["components"].pop(old_name)
|
||||||
self._config["initialize"]["components"][new_name] = init_cfg
|
self._config["initialize"]["components"][new_name] = init_cfg
|
||||||
|
pipe = self.get_pipe(new_name)
|
||||||
|
if hasattr(pipe, "name"):
|
||||||
|
pipe.name = new_name
|
||||||
|
self._link_components()
|
||||||
|
|
||||||
def remove_pipe(self, name: str) -> Tuple[str, PipeCallable]:
|
def remove_pipe(self, name: str) -> Tuple[str, PipeCallable]:
|
||||||
"""Remove a component from the pipeline.
|
"""Remove a component from the pipeline.
|
||||||
|
@ -951,6 +961,7 @@ class Language:
|
||||||
# Make sure the name is also removed from the set of disabled components
|
# Make sure the name is also removed from the set of disabled components
|
||||||
if name in self.disabled:
|
if name in self.disabled:
|
||||||
self._disabled.remove(name)
|
self._disabled.remove(name)
|
||||||
|
self._link_components()
|
||||||
return removed
|
return removed
|
||||||
|
|
||||||
def disable_pipe(self, name: str) -> None:
|
def disable_pipe(self, name: str) -> None:
|
||||||
|
@ -1676,8 +1687,16 @@ class Language:
|
||||||
# The problem is we need to do it during deserialization...And the
|
# The problem is we need to do it during deserialization...And the
|
||||||
# components don't receive the pipeline then. So this does have to be
|
# components don't receive the pipeline then. So this does have to be
|
||||||
# here :(
|
# here :(
|
||||||
|
# First, fix up all the internal component names in case they have
|
||||||
|
# gotten out of sync due to sourcing components from different
|
||||||
|
# pipelines, since find_listeners uses proc2.name for the listener
|
||||||
|
# map.
|
||||||
|
for name, proc in self.pipeline:
|
||||||
|
if hasattr(proc, "name"):
|
||||||
|
proc.name = name
|
||||||
for i, (name1, proc1) in enumerate(self.pipeline):
|
for i, (name1, proc1) in enumerate(self.pipeline):
|
||||||
if isinstance(proc1, ty.ListenedToComponent):
|
if isinstance(proc1, ty.ListenedToComponent):
|
||||||
|
proc1.listener_map = {}
|
||||||
for name2, proc2 in self.pipeline[i + 1 :]:
|
for name2, proc2 in self.pipeline[i + 1 :]:
|
||||||
proc1.find_listeners(proc2)
|
proc1.find_listeners(proc2)
|
||||||
|
|
||||||
|
@ -1811,6 +1830,7 @@ class Language:
|
||||||
raw_config=raw_config,
|
raw_config=raw_config,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
assert "source" in pipe_cfg
|
||||||
# We need the sourced components to reference the same
|
# We need the sourced components to reference the same
|
||||||
# vocab without modifying the current vocab state **AND**
|
# vocab without modifying the current vocab state **AND**
|
||||||
# we still want to load the source model vectors to perform
|
# we still want to load the source model vectors to perform
|
||||||
|
@ -1830,6 +1850,10 @@ class Language:
|
||||||
source_name = pipe_cfg.get("component", pipe_name)
|
source_name = pipe_cfg.get("component", pipe_name)
|
||||||
listeners_replaced = False
|
listeners_replaced = False
|
||||||
if "replace_listeners" in pipe_cfg:
|
if "replace_listeners" in pipe_cfg:
|
||||||
|
# Make sure that the listened-to component has the
|
||||||
|
# state of the source pipeline listener map so that the
|
||||||
|
# replace_listeners method below works as intended.
|
||||||
|
source_nlps[model]._link_components()
|
||||||
for name, proc in source_nlps[model].pipeline:
|
for name, proc in source_nlps[model].pipeline:
|
||||||
if source_name in getattr(proc, "listening_components", []):
|
if source_name in getattr(proc, "listening_components", []):
|
||||||
source_nlps[model].replace_listeners(
|
source_nlps[model].replace_listeners(
|
||||||
|
@ -1841,6 +1865,8 @@ class Language:
|
||||||
nlp.add_pipe(
|
nlp.add_pipe(
|
||||||
source_name, source=source_nlps[model], name=pipe_name
|
source_name, source=source_nlps[model], name=pipe_name
|
||||||
)
|
)
|
||||||
|
# At this point after nlp.add_pipe, the listener map
|
||||||
|
# corresponds to the new pipeline.
|
||||||
if model not in source_nlp_vectors_hashes:
|
if model not in source_nlp_vectors_hashes:
|
||||||
source_nlp_vectors_hashes[model] = hash(
|
source_nlp_vectors_hashes[model] = hash(
|
||||||
source_nlps[model].vocab.vectors.to_bytes(
|
source_nlps[model].vocab.vectors.to_bytes(
|
||||||
|
|
|
@ -67,7 +67,8 @@ def init_nlp(config: Config, *, use_gpu: int = -1) -> "Language":
|
||||||
with nlp.select_pipes(enable=resume_components):
|
with nlp.select_pipes(enable=resume_components):
|
||||||
logger.info("Resuming training for: %s", resume_components)
|
logger.info("Resuming training for: %s", resume_components)
|
||||||
nlp.resume_training(sgd=optimizer)
|
nlp.resume_training(sgd=optimizer)
|
||||||
# Make sure that listeners are defined before initializing further
|
# Make sure that internal component names are synced and listeners are
|
||||||
|
# defined before initializing further
|
||||||
nlp._link_components()
|
nlp._link_components()
|
||||||
with nlp.select_pipes(disable=[*frozen_components, *resume_components]):
|
with nlp.select_pipes(disable=[*frozen_components, *resume_components]):
|
||||||
if T["max_epochs"] == -1:
|
if T["max_epochs"] == -1:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user