Merge pull request #5824 from svlandeg/fix/textcat-v3

This commit is contained in:
Ines Montani 2020-07-28 15:04:25 +02:00 committed by GitHub
commit b83ead5bf5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 89 additions and 29 deletions

View File

@ -20,6 +20,7 @@ import spacy
from spacy import util
from spacy.util import minibatch, compounding
from spacy.gold import Example
from thinc.api import Config
@plac.annotations(
@ -42,8 +43,8 @@ def main(config_path, output_dir=None, n_iter=20, n_texts=2000, init_tok2vec=Non
output_dir.mkdir()
print(f"Loading nlp model from {config_path}")
nlp_config = util.load_config(config_path, create_objects=False)["nlp"]
nlp = util.load_model_from_config(nlp_config)
nlp_config = Config().from_disk(config_path)
nlp, _ = util.load_model_from_config(nlp_config, auto_fill=True)
# ensure the nlp object was defined with a textcat component
if "textcat" not in nlp.pipe_names:

View File

@ -1,19 +1,14 @@
[nlp]
lang = "en"
pipeline = ["textcat"]
[nlp.pipeline.textcat]
[components]
[components.textcat]
factory = "textcat"
[nlp.pipeline.textcat.model]
@architectures = "spacy.TextCatCNN.v1"
exclusive_classes = false
[nlp.pipeline.textcat.model.tok2vec]
@architectures = "spacy.HashEmbedCNN.v1"
pretrained_vectors = null
width = 96
depth = 4
embed_size = 2000
window_size = 1
maxout_pieces = 3
subword_features = true
[components.textcat.model]
@architectures = "spacy.TextCatBOW.v1"
exclusive_classes = true
ngram_size = 1
no_output_layer = false

View File

@ -8,6 +8,7 @@ import typer
from ._util import Arg, Opt, debug_cli, show_validation_error, parse_config_overrides
from .. import util
from ..lang.en import English
from ..util import dot_to_object
@debug_cli.command("model")
@ -60,16 +61,7 @@ def debug_model_cli(
msg.info(f"Fixing random seed: {seed}")
fix_random_seed(seed)
component = config
parts = section.split(".")
for item in parts:
try:
component = component[item]
except KeyError:
msg.fail(
f"The section '{section}' is not a valid section in the provided config.",
exits=1,
)
component = dot_to_object(config, section)
if hasattr(component, "model"):
model = component.model
else:

View File

@ -592,7 +592,7 @@ class Errors:
"for the `nlp` pipeline with components {names}.")
E993 = ("The config for 'nlp' needs to include a key 'lang' specifying "
"the code of the language to initialize it with (for example "
"'en' for English).\n\n{config}")
"'en' for English) - this can't be 'None'.\n\n{config}")
E996 = ("Could not parse {file}: {msg}")
E997 = ("Tokenizer special cases are not allowed to modify the text. "
"This would map '{chunk}' to '{orth}' given token attributes "

View File

@ -2,7 +2,13 @@ import pytest
from .util import get_random_doc
from spacy.util import minibatch_by_words
from spacy import util
from spacy.util import minibatch_by_words, dot_to_object
from thinc.api import Config, Optimizer
from ..lang.en import English
from ..lang.nl import Dutch
from ..language import DEFAULT_CONFIG_PATH
@pytest.mark.parametrize(
@ -56,3 +62,49 @@ def test_util_minibatch_oversize(doc_sizes, expected_batches):
minibatch_by_words(docs, size=batch_size, tolerance=tol, discard_oversize=False)
)
assert [len(batch) for batch in batches] == expected_batches
def test_util_dot_section():
cfg_string = """
[nlp]
lang = "en"
pipeline = ["textcat"]
load_vocab_data = false
[components]
[components.textcat]
factory = "textcat"
[components.textcat.model]
@architectures = "spacy.TextCatBOW.v1"
exclusive_classes = true
ngram_size = 1
no_output_layer = false
"""
nlp_config = Config().from_str(cfg_string)
en_nlp, en_config = util.load_model_from_config(nlp_config, auto_fill=True)
default_config = Config().from_disk(DEFAULT_CONFIG_PATH)
default_config["nlp"]["lang"] = "nl"
nl_nlp, nl_config = util.load_model_from_config(default_config, auto_fill=True)
# Test that creation went OK
assert isinstance(en_nlp, English)
assert isinstance(nl_nlp, Dutch)
assert nl_nlp.pipe_names == []
assert en_nlp.pipe_names == ["textcat"]
assert en_nlp.get_pipe("textcat").model.attrs["multi_label"] == False # not exclusive_classes
# Test that default values got overwritten
assert not en_config["nlp"]["load_vocab_data"]
assert nl_config["nlp"]["load_vocab_data"] # default value True
# Test proper functioning of 'dot_to_object'
with pytest.raises(KeyError):
obj = dot_to_object(en_config, "nlp.pipeline.tagger")
with pytest.raises(KeyError):
obj = dot_to_object(en_config, "nlp.unknownattribute")
assert not dot_to_object(en_config, "nlp.load_vocab_data")
assert dot_to_object(nl_config, "nlp.load_vocab_data")
assert isinstance(dot_to_object(nl_config, "training.optimizer"), Optimizer)

View File

@ -258,7 +258,7 @@ def load_model_from_config(
if "nlp" not in config:
raise ValueError(Errors.E985.format(config=config))
nlp_config = config["nlp"]
if "lang" not in nlp_config:
if "lang" not in nlp_config or nlp_config["lang"] is None:
raise ValueError(Errors.E993.format(config=nlp_config))
# This will automatically handle all codes registered via the languages
# registry, including custom subclasses provided via entry points
@ -1107,6 +1107,26 @@ def dict_to_dot(obj: Dict[str, dict]) -> Dict[str, Any]:
return {".".join(key): value for key, value in walk_dict(obj)}
def dot_to_object(config: Config, section: str):
"""Convert dot notation of a "section" to a specific part of the Config.
e.g. "training.optimizer" would return the Optimizer object.
Throws an error if the section is not defined in this config.
config (Config): The config.
section (str): The dot notation of the section in the config.
RETURNS: The object denoted by the section
"""
component = config
parts = section.split(".")
for item in parts:
try:
component = component[item]
except (KeyError, TypeError) as e:
msg = f"The section '{section}' is not a valid section in the provided config."
raise KeyError(msg)
return component
def walk_dict(
node: Dict[str, Any], parent: List[str] = []
) -> Iterator[Tuple[List[str], Any]]: