@factories -> factory (#5801)

This commit is contained in:
Ines Montani 2020-07-22 17:29:31 +02:00 committed by GitHub
parent be476e495e
commit d0c6d1efc5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 12 additions and 12 deletions

View File

@ -526,7 +526,7 @@ def verify_config(nlp: Language) -> None:
# in config["nlp"]["pipeline"] instead? # in config["nlp"]["pipeline"] instead?
for pipe_config in nlp.config["components"].values(): for pipe_config in nlp.config["components"].values():
# We can't assume that the component name == the factory # We can't assume that the component name == the factory
factory = pipe_config["@factories"] factory = pipe_config["factory"]
if factory == "textcat": if factory == "textcat":
verify_textcat_config(nlp, pipe_config) verify_textcat_config(nlp, pipe_config)

View File

@ -564,9 +564,9 @@ class Errors:
"into {values}, but found {value}.") "into {values}, but found {value}.")
E983 = ("Invalid key for '{dict}': {key}. Available keys: " E983 = ("Invalid key for '{dict}': {key}. Available keys: "
"{keys}") "{keys}")
E984 = ("Invalid component config for '{name}': no @factories key " E984 = ("Invalid component config for '{name}': no 'factory' key "
"specifying the registered function used to initialize the " "specifying the registered function used to initialize the "
"component. For example, @factories = \"ner\" will use the 'ner' " "component. For example, factory = \"ner\" will use the 'ner' "
"factory and all other settings in the block will be passed " "factory and all other settings in the block will be passed "
"to it as arguments.\n\n{config}") "to it as arguments.\n\n{config}")
E985 = ("Can't load model from config file: no 'nlp' section found.\n\n{config}") E985 = ("Can't load model from config file: no 'nlp' section found.\n\n{config}")

View File

@ -185,7 +185,7 @@ class Language:
for pipe_name in self.pipe_names: for pipe_name in self.pipe_names:
pipe_meta = self.get_pipe_meta(pipe_name) pipe_meta = self.get_pipe_meta(pipe_name)
pipe_config = self.get_pipe_config(pipe_name) pipe_config = self.get_pipe_config(pipe_name)
pipeline[pipe_name] = {"@factories": pipe_meta.factory, **pipe_config} pipeline[pipe_name] = {"factory": pipe_meta.factory, **pipe_config}
self._config["nlp"]["pipeline"] = self.pipe_names self._config["nlp"]["pipeline"] = self.pipe_names
self._config["components"] = pipeline self._config["components"] = pipeline
if not srsly.is_json_serializable(self._config): if not srsly.is_json_serializable(self._config):
@ -491,7 +491,7 @@ class Language:
# pipeline component and why it failed, explain default config # pipeline component and why it failed, explain default config
resolved, filled = registry.resolve(cfg, validate=validate, overrides=overrides) resolved, filled = registry.resolve(cfg, validate=validate, overrides=overrides)
filled = filled[factory_name] filled = filled[factory_name]
filled["@factories"] = factory_name filled["factory"] = factory_name
self._pipe_configs[name] = filled self._pipe_configs[name] = filled
return resolved[factory_name] return resolved[factory_name]
@ -1283,12 +1283,12 @@ class Language:
if pipe_name not in pipeline: if pipe_name not in pipeline:
opts = ", ".join(pipeline.keys()) opts = ", ".join(pipeline.keys())
raise ValueError(Errors.E956.format(name=pipe_name, opts=opts)) raise ValueError(Errors.E956.format(name=pipe_name, opts=opts))
pipe_cfg = pipeline[pipe_name] pipe_cfg = util.copy_config(pipeline[pipe_name])
if pipe_name not in disable: if pipe_name not in disable:
if "@factories" not in pipe_cfg: if "factory" not in pipe_cfg:
err = Errors.E984.format(name=pipe_name, config=pipe_cfg) err = Errors.E984.format(name=pipe_name, config=pipe_cfg)
raise ValueError(err) raise ValueError(err)
factory = pipe_cfg["@factories"] factory = pipe_cfg.pop("factory")
# The pipe name (key in the config) here is the unique name of the # The pipe name (key in the config) here is the unique name of the
# component, not necessarily the factory # component, not necessarily the factory
nlp.add_pipe( nlp.add_pipe(

View File

@ -20,7 +20,7 @@ pipeline = ["tok2vec", "tagger"]
[components] [components]
[components.tok2vec] [components.tok2vec]
@factories = "tok2vec" factory = "tok2vec"
[components.tok2vec.model] [components.tok2vec.model]
@architectures = "spacy.HashEmbedCNN.v1" @architectures = "spacy.HashEmbedCNN.v1"
@ -34,7 +34,7 @@ subword_features = true
dropout = null dropout = null
[components.tagger] [components.tagger]
@factories = "tagger" factory = "tagger"
[components.tagger.model] [components.tagger.model]
@architectures = "spacy.Tagger.v1" @architectures = "spacy.Tagger.v1"
@ -245,7 +245,7 @@ def test_serialize_config_language_specific():
nlp.add_pipe(name, config={"foo": 100}, name="bar") nlp.add_pipe(name, config={"foo": 100}, name="bar")
pipe_config = nlp.config["components"]["bar"] pipe_config = nlp.config["components"]["bar"]
assert pipe_config["foo"] == 100 assert pipe_config["foo"] == 100
assert pipe_config["@factories"] == name assert pipe_config["factory"] == name
with make_tempdir() as d: with make_tempdir() as d:
nlp.to_disk(d) nlp.to_disk(d)
@ -255,7 +255,7 @@ def test_serialize_config_language_specific():
assert nlp2.get_pipe_meta("bar").factory == name assert nlp2.get_pipe_meta("bar").factory == name
pipe_config = nlp2.config["components"]["bar"] pipe_config = nlp2.config["components"]["bar"]
assert pipe_config["foo"] == 100 assert pipe_config["foo"] == 100
assert pipe_config["@factories"] == name assert pipe_config["factory"] == name
config = Config().from_str(nlp2.config.to_str()) config = Config().from_str(nlp2.config.to_str())
config["nlp"]["lang"] = "de" config["nlp"]["lang"] = "de"