mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-14 18:22:27 +03:00
link components in add and remove pipe
This commit is contained in:
parent
9b7a59c325
commit
fb70f8813d
|
@ -793,6 +793,7 @@ class Language:
|
||||||
pipe_index = self._get_pipe_index(before, after, first, last)
|
pipe_index = self._get_pipe_index(before, after, first, last)
|
||||||
self._pipe_meta[name] = self.get_factory_meta(factory_name)
|
self._pipe_meta[name] = self.get_factory_meta(factory_name)
|
||||||
self._components.insert(pipe_index, (name, pipe_component))
|
self._components.insert(pipe_index, (name, pipe_component))
|
||||||
|
self._link_components()
|
||||||
return pipe_component
|
return pipe_component
|
||||||
|
|
||||||
def _get_pipe_index(
|
def _get_pipe_index(
|
||||||
|
@ -951,6 +952,7 @@ class Language:
|
||||||
# Make sure the name is also removed from the set of disabled components
|
# Make sure the name is also removed from the set of disabled components
|
||||||
if name in self.disabled:
|
if name in self.disabled:
|
||||||
self._disabled.remove(name)
|
self._disabled.remove(name)
|
||||||
|
self._link_components()
|
||||||
return removed
|
return removed
|
||||||
|
|
||||||
def disable_pipe(self, name: str) -> None:
|
def disable_pipe(self, name: str) -> None:
|
||||||
|
@ -1310,7 +1312,6 @@ class Language:
|
||||||
if pretrain_cfg:
|
if pretrain_cfg:
|
||||||
P = registry.resolve(pretrain_cfg, schema=ConfigSchemaPretrain)
|
P = registry.resolve(pretrain_cfg, schema=ConfigSchemaPretrain)
|
||||||
init_tok2vec(self, P, I)
|
init_tok2vec(self, P, I)
|
||||||
self._link_components()
|
|
||||||
self._optimizer = sgd
|
self._optimizer = sgd
|
||||||
if sgd is not None:
|
if sgd is not None:
|
||||||
self._optimizer = sgd
|
self._optimizer = sgd
|
||||||
|
@ -1678,6 +1679,7 @@ class Language:
|
||||||
# here :(
|
# here :(
|
||||||
for i, (name1, proc1) in enumerate(self.pipeline):
|
for i, (name1, proc1) in enumerate(self.pipeline):
|
||||||
if isinstance(proc1, ty.ListenedToComponent):
|
if isinstance(proc1, ty.ListenedToComponent):
|
||||||
|
proc1.listener_map = {}
|
||||||
for name2, proc2 in self.pipeline[i + 1 :]:
|
for name2, proc2 in self.pipeline[i + 1 :]:
|
||||||
proc1.find_listeners(proc2)
|
proc1.find_listeners(proc2)
|
||||||
|
|
||||||
|
@ -1811,6 +1813,7 @@ class Language:
|
||||||
raw_config=raw_config,
|
raw_config=raw_config,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
assert "source" in pipe_cfg
|
||||||
# We need the sourced components to reference the same
|
# We need the sourced components to reference the same
|
||||||
# vocab without modifying the current vocab state **AND**
|
# vocab without modifying the current vocab state **AND**
|
||||||
# we still want to load the source model vectors to perform
|
# we still want to load the source model vectors to perform
|
||||||
|
|
|
@ -189,10 +189,10 @@ def test_tok2vec_listener(with_vectors):
|
||||||
tagger.add_label(tag)
|
tagger.add_label(tag)
|
||||||
|
|
||||||
# Check that the Tok2Vec component finds it listeners
|
# Check that the Tok2Vec component finds it listeners
|
||||||
assert tok2vec.listeners == []
|
|
||||||
optimizer = nlp.initialize(lambda: train_examples)
|
|
||||||
assert tok2vec.listeners == [tagger_tok2vec]
|
assert tok2vec.listeners == [tagger_tok2vec]
|
||||||
|
|
||||||
|
# Initialize and train
|
||||||
|
optimizer = nlp.initialize(lambda: train_examples)
|
||||||
for i in range(5):
|
for i in range(5):
|
||||||
losses = {}
|
losses = {}
|
||||||
nlp.update(train_examples, sgd=optimizer, losses=losses)
|
nlp.update(train_examples, sgd=optimizer, losses=losses)
|
||||||
|
@ -540,3 +540,17 @@ def test_tok2vec_listeners_textcat():
|
||||||
assert cats1["imperative"] < 0.9
|
assert cats1["imperative"] < 0.9
|
||||||
assert [t.tag_ for t in docs[0]] == ["V", "J", "N"]
|
assert [t.tag_ for t in docs[0]] == ["V", "J", "N"]
|
||||||
assert [t.tag_ for t in docs[1]] == ["N", "V", "J", "N"]
|
assert [t.tag_ for t in docs[1]] == ["N", "V", "J", "N"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_tok2vec_listener_source_link():
|
||||||
|
orig_config = Config().from_str(cfg_string_multi)
|
||||||
|
nlp1 = util.load_model_from_config(orig_config, auto_fill=True, validate=True)
|
||||||
|
assert list(nlp1.get_pipe("tok2vec").listener_map.keys()) == ["tagger", "ner"]
|
||||||
|
|
||||||
|
nlp2 = English()
|
||||||
|
nlp2.add_pipe("tok2vec", source=nlp1)
|
||||||
|
assert nlp2.get_pipe("tok2vec").listener_map == {}
|
||||||
|
nlp2.add_pipe("tagger", source=nlp1)
|
||||||
|
assert list(nlp2.get_pipe("tok2vec").listener_map.keys()) == ["tagger"]
|
||||||
|
nlp2.add_pipe("ner", source=nlp1)
|
||||||
|
assert list(nlp2.get_pipe("tok2vec").listener_map.keys()) == ["tagger", "ner"]
|
Loading…
Reference in New Issue
Block a user