warn when frozen components break listener pattern (#6766)

* warn when frozen components break listener pattern

* few notes in the documentation

* update arg name

* formatting

* cleanup

* specify listeners return type
This commit is contained in:
Sofie Van Landeghem 2021-01-20 01:12:35 +01:00 committed by GitHub
parent 88acbfc050
commit 57640aa838
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 56 additions and 18 deletions

View File

@ -1496,8 +1496,7 @@ class Language:
for i, (name1, proc1) in enumerate(self.pipeline):
if hasattr(proc1, "find_listeners"):
for name2, proc2 in self.pipeline[i + 1 :]:
if isinstance(getattr(proc2, "model", None), Model):
proc1.find_listeners(proc2.model)
proc1.find_listeners(proc2)
@classmethod
def from_config(

View File

@ -62,28 +62,42 @@ class Tok2Vec(TrainablePipe):
self.vocab = vocab
self.model = model
self.name = name
self.listeners = []
self.listener_map = {}
self.cfg = {}
def add_listener(self, listener: "Tok2VecListener") -> None:
"""Add a listener for a downstream component. Usually internals."""
self.listeners.append(listener)
@property
def listeners(self) -> List["Tok2VecListener"]:
"""RETURNS (List[Tok2VecListener]): The listener models listening to this
component. Usually internals.
"""
return [m for c in self.listening_components for m in self.listener_map[c]]
def find_listeners(self, model: Model) -> None:
"""Walk over a model, looking for layers that are Tok2vecListener
subclasses that have an upstream_name that matches this component.
Listeners can also set their upstream_name attribute to the wildcard
string '*' to match any `Tok2Vec`.
@property
def listening_components(self) -> List[str]:
"""RETURNS (List[str]): The downstream components listening to this
component. Usually internals.
"""
return list(self.listener_map.keys())
def add_listener(self, listener: "Tok2VecListener", component_name: str) -> None:
"""Add a listener for a downstream component. Usually internals."""
self.listener_map.setdefault(component_name, [])
self.listener_map[component_name].append(listener)
def find_listeners(self, component) -> None:
"""Walk over a model of a processing component, looking for layers that
are Tok2vecListener subclasses that have an upstream_name that matches
this component. Listeners can also set their upstream_name attribute to
the wildcard string '*' to match any `Tok2Vec`.
You're unlikely to ever need multiple `Tok2Vec` components, so it's
fine to leave your listeners upstream_name on '*'.
"""
for node in model.walk():
if isinstance(node, Tok2VecListener) and node.upstream_name in (
"*",
self.name,
):
self.add_listener(node)
names = ("*", self.name)
if isinstance(getattr(component, "model", None), Model):
for node in component.model.walk():
if isinstance(node, Tok2VecListener) and node.upstream_name in names:
self.add_listener(node, component.name)
def __call__(self, doc: Doc) -> Doc:
"""Add context-sensitive embeddings to the Doc.tensor attribute, allowing

View File

@ -66,6 +66,20 @@ def init_nlp(config: Config, *, use_gpu: int = -1) -> "Language":
with nlp.select_pipes(disable=[*frozen_components, *resume_components]):
nlp.initialize(lambda: train_corpus(nlp), sgd=optimizer)
logger.info(f"Initialized pipeline components: {nlp.pipe_names}")
# Detect components with listeners that are not frozen consistently
for name, proc in nlp.pipeline:
if getattr(proc, "listening_components", None):
for listener in proc.listening_components:
if listener in frozen_components and name not in frozen_components:
logger.warn(f"Component '{name}' will be (re)trained, but the "
f"'{listener}' depends on it and is frozen. This means "
f"that the performance of the '{listener}' will be degraded. "
f"You should either freeze both, or neither of the two.")
if listener not in frozen_components and name in frozen_components:
logger.warn(f"Component '{listener}' will be (re)trained, but it needs the "
f"'{name}' which is frozen. "
f"You should either freeze both, or neither of the two.")
return nlp

View File

@ -400,7 +400,8 @@ vectors available otherwise, it won't be able to make the same predictions.
> ```
>
> By default, sourced components will be updated with your data during training.
> If you want to preserve the component as-is, you can "freeze" it:
> If you want to preserve the component as-is, you can "freeze" it if the pipeline
> is not using a shared `Tok2Vec` layer:
>
> ```ini
> [training]

View File

@ -419,6 +419,16 @@ pipeline = ["parser", "ner", "textcat", "custom"]
frozen_components = ["parser", "custom"]
```
<Infobox variant="warning" title="Shared Tok2Vec layer">
When the components in your pipeline
[share an embedding layer](/usage/embeddings-transformers#embedding-layers), the
**performance** of your frozen component will be **degraded** if you continue training
other layers with the same underlying `Tok2Vec` instance. As a rule of thumb,
ensure that your frozen components are truly **independent** in the pipeline.
</Infobox>
### Using registered functions {#config-functions}
The training configuration defined in the config file doesn't have to only