@factories -> factory (#5801)

This commit is contained in:
Ines Montani 2020-07-22 17:29:31 +02:00 committed by GitHub
parent be476e495e
commit d0c6d1efc5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 12 additions and 12 deletions

View File

@ -526,7 +526,7 @@ def verify_config(nlp: Language) -> None:
# in config["nlp"]["pipeline"] instead?
for pipe_config in nlp.config["components"].values():
# We can't assume that the component name == the factory
factory = pipe_config["@factories"]
factory = pipe_config["factory"]
if factory == "textcat":
verify_textcat_config(nlp, pipe_config)

View File

@ -564,9 +564,9 @@ class Errors:
"into {values}, but found {value}.")
E983 = ("Invalid key for '{dict}': {key}. Available keys: "
"{keys}")
E984 = ("Invalid component config for '{name}': no @factories key "
E984 = ("Invalid component config for '{name}': no 'factory' key "
"specifying the registered function used to initialize the "
"component. For example, @factories = \"ner\" will use the 'ner' "
"component. For example, factory = \"ner\" will use the 'ner' "
"factory and all other settings in the block will be passed "
"to it as arguments.\n\n{config}")
E985 = ("Can't load model from config file: no 'nlp' section found.\n\n{config}")

View File

@ -185,7 +185,7 @@ class Language:
for pipe_name in self.pipe_names:
pipe_meta = self.get_pipe_meta(pipe_name)
pipe_config = self.get_pipe_config(pipe_name)
pipeline[pipe_name] = {"@factories": pipe_meta.factory, **pipe_config}
pipeline[pipe_name] = {"factory": pipe_meta.factory, **pipe_config}
self._config["nlp"]["pipeline"] = self.pipe_names
self._config["components"] = pipeline
if not srsly.is_json_serializable(self._config):
@ -491,7 +491,7 @@ class Language:
# pipeline component and why it failed, explain default config
resolved, filled = registry.resolve(cfg, validate=validate, overrides=overrides)
filled = filled[factory_name]
filled["@factories"] = factory_name
filled["factory"] = factory_name
self._pipe_configs[name] = filled
return resolved[factory_name]
@ -1283,12 +1283,12 @@ class Language:
if pipe_name not in pipeline:
opts = ", ".join(pipeline.keys())
raise ValueError(Errors.E956.format(name=pipe_name, opts=opts))
pipe_cfg = pipeline[pipe_name]
pipe_cfg = util.copy_config(pipeline[pipe_name])
if pipe_name not in disable:
if "@factories" not in pipe_cfg:
if "factory" not in pipe_cfg:
err = Errors.E984.format(name=pipe_name, config=pipe_cfg)
raise ValueError(err)
factory = pipe_cfg["@factories"]
factory = pipe_cfg.pop("factory")
# The pipe name (key in the config) here is the unique name of the
# component, not necessarily the factory
nlp.add_pipe(

View File

@ -20,7 +20,7 @@ pipeline = ["tok2vec", "tagger"]
[components]
[components.tok2vec]
@factories = "tok2vec"
factory = "tok2vec"
[components.tok2vec.model]
@architectures = "spacy.HashEmbedCNN.v1"
@ -34,7 +34,7 @@ subword_features = true
dropout = null
[components.tagger]
@factories = "tagger"
factory = "tagger"
[components.tagger.model]
@architectures = "spacy.Tagger.v1"
@ -245,7 +245,7 @@ def test_serialize_config_language_specific():
nlp.add_pipe(name, config={"foo": 100}, name="bar")
pipe_config = nlp.config["components"]["bar"]
assert pipe_config["foo"] == 100
assert pipe_config["@factories"] == name
assert pipe_config["factory"] == name
with make_tempdir() as d:
nlp.to_disk(d)
@ -255,7 +255,7 @@ def test_serialize_config_language_specific():
assert nlp2.get_pipe_meta("bar").factory == name
pipe_config = nlp2.config["components"]["bar"]
assert pipe_config["foo"] == 100
assert pipe_config["@factories"] == name
assert pipe_config["factory"] == name
config = Config().from_str(nlp2.config.to_str())
config["nlp"]["lang"] = "de"