Merge pull request #5824 from svlandeg/fix/textcat-v3

This commit is contained in:
Ines Montani 2020-07-28 15:04:25 +02:00 committed by GitHub
commit b83ead5bf5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 89 additions and 29 deletions

View File

@@ -20,6 +20,7 @@ import spacy
from spacy import util from spacy import util
from spacy.util import minibatch, compounding from spacy.util import minibatch, compounding
from spacy.gold import Example from spacy.gold import Example
from thinc.api import Config
@plac.annotations( @plac.annotations(
@@ -42,8 +43,8 @@ def main(config_path, output_dir=None, n_iter=20, n_texts=2000, init_tok2vec=None
output_dir.mkdir() output_dir.mkdir()
print(f"Loading nlp model from {config_path}") print(f"Loading nlp model from {config_path}")
nlp_config = util.load_config(config_path, create_objects=False)["nlp"] nlp_config = Config().from_disk(config_path)
nlp = util.load_model_from_config(nlp_config) nlp, _ = util.load_model_from_config(nlp_config, auto_fill=True)
# ensure the nlp object was defined with a textcat component # ensure the nlp object was defined with a textcat component
if "textcat" not in nlp.pipe_names: if "textcat" not in nlp.pipe_names:

View File

@@ -1,19 +1,14 @@
[nlp] [nlp]
lang = "en" lang = "en"
pipeline = ["textcat"]
[nlp.pipeline.textcat] [components]
[components.textcat]
factory = "textcat" factory = "textcat"
[nlp.pipeline.textcat.model] [components.textcat.model]
@architectures = "spacy.TextCatCNN.v1" @architectures = "spacy.TextCatBOW.v1"
exclusive_classes = false exclusive_classes = true
ngram_size = 1
[nlp.pipeline.textcat.model.tok2vec] no_output_layer = false
@architectures = "spacy.HashEmbedCNN.v1"
pretrained_vectors = null
width = 96
depth = 4
embed_size = 2000
window_size = 1
maxout_pieces = 3
subword_features = true

View File

@@ -8,6 +8,7 @@ import typer
from ._util import Arg, Opt, debug_cli, show_validation_error, parse_config_overrides from ._util import Arg, Opt, debug_cli, show_validation_error, parse_config_overrides
from .. import util from .. import util
from ..lang.en import English from ..lang.en import English
from ..util import dot_to_object
@debug_cli.command("model") @debug_cli.command("model")
@@ -60,16 +61,7 @@ def debug_model_cli(
msg.info(f"Fixing random seed: {seed}") msg.info(f"Fixing random seed: {seed}")
fix_random_seed(seed) fix_random_seed(seed)
component = config component = dot_to_object(config, section)
parts = section.split(".")
for item in parts:
try:
component = component[item]
except KeyError:
msg.fail(
f"The section '{section}' is not a valid section in the provided config.",
exits=1,
)
if hasattr(component, "model"): if hasattr(component, "model"):
model = component.model model = component.model
else: else:

View File

@@ -592,7 +592,7 @@ class Errors:
"for the `nlp` pipeline with components {names}.") "for the `nlp` pipeline with components {names}.")
E993 = ("The config for 'nlp' needs to include a key 'lang' specifying " E993 = ("The config for 'nlp' needs to include a key 'lang' specifying "
"the code of the language to initialize it with (for example " "the code of the language to initialize it with (for example "
"'en' for English).\n\n{config}") "'en' for English) - this can't be 'None'.\n\n{config}")
E996 = ("Could not parse {file}: {msg}") E996 = ("Could not parse {file}: {msg}")
E997 = ("Tokenizer special cases are not allowed to modify the text. " E997 = ("Tokenizer special cases are not allowed to modify the text. "
"This would map '{chunk}' to '{orth}' given token attributes " "This would map '{chunk}' to '{orth}' given token attributes "

View File

@@ -2,7 +2,13 @@ import pytest
from .util import get_random_doc from .util import get_random_doc
from spacy.util import minibatch_by_words from spacy import util
from spacy.util import minibatch_by_words, dot_to_object
from thinc.api import Config, Optimizer
from ..lang.en import English
from ..lang.nl import Dutch
from ..language import DEFAULT_CONFIG_PATH
@pytest.mark.parametrize( @pytest.mark.parametrize(
@@ -56,3 +62,49 @@ def test_util_minibatch_oversize(doc_sizes, expected_batches):
minibatch_by_words(docs, size=batch_size, tolerance=tol, discard_oversize=False) minibatch_by_words(docs, size=batch_size, tolerance=tol, discard_oversize=False)
) )
assert [len(batch) for batch in batches] == expected_batches assert [len(batch) for batch in batches] == expected_batches
def test_util_dot_section():
    """Exercise `load_model_from_config` auto-fill and `dot_to_object` lookup.

    Builds an English pipeline from an inline config string and a Dutch
    pipeline from the packaged default config, then checks that explicit
    values override defaults and that dotted-section lookup behaves.
    """
    cfg_string = """
    [nlp]
    lang = "en"
    pipeline = ["textcat"]
    load_vocab_data = false

    [components]

    [components.textcat]
    factory = "textcat"

    [components.textcat.model]
    @architectures = "spacy.TextCatBOW.v1"
    exclusive_classes = true
    ngram_size = 1
    no_output_layer = false
    """
    nlp_config = Config().from_str(cfg_string)
    en_nlp, en_config = util.load_model_from_config(nlp_config, auto_fill=True)
    default_config = Config().from_disk(DEFAULT_CONFIG_PATH)
    default_config["nlp"]["lang"] = "nl"
    nl_nlp, nl_config = util.load_model_from_config(default_config, auto_fill=True)
    # Test that creation went OK
    assert isinstance(en_nlp, English)
    assert isinstance(nl_nlp, Dutch)
    assert nl_nlp.pipe_names == []
    assert en_nlp.pipe_names == ["textcat"]
    # `multi_label` is the inverse of `exclusive_classes` in the model config
    assert en_nlp.get_pipe("textcat").model.attrs["multi_label"] is False
    # Test that default values got overwritten
    assert not en_config["nlp"]["load_vocab_data"]
    assert nl_config["nlp"]["load_vocab_data"]  # default value True
    # Test proper functioning of 'dot_to_object'
    with pytest.raises(KeyError):
        dot_to_object(en_config, "nlp.pipeline.tagger")
    with pytest.raises(KeyError):
        dot_to_object(en_config, "nlp.unknownattribute")
    assert not dot_to_object(en_config, "nlp.load_vocab_data")
    assert dot_to_object(nl_config, "nlp.load_vocab_data")
    assert isinstance(dot_to_object(nl_config, "training.optimizer"), Optimizer)

View File

@@ -258,7 +258,7 @@ def load_model_from_config(
if "nlp" not in config: if "nlp" not in config:
raise ValueError(Errors.E985.format(config=config)) raise ValueError(Errors.E985.format(config=config))
nlp_config = config["nlp"] nlp_config = config["nlp"]
if "lang" not in nlp_config: if "lang" not in nlp_config or nlp_config["lang"] is None:
raise ValueError(Errors.E993.format(config=nlp_config)) raise ValueError(Errors.E993.format(config=nlp_config))
# This will automatically handle all codes registered via the languages # This will automatically handle all codes registered via the languages
# registry, including custom subclasses provided via entry points # registry, including custom subclasses provided via entry points
@@ -1107,6 +1107,26 @@ def dict_to_dot(obj: Dict[str, dict]) -> Dict[str, Any]:
return {".".join(key): value for key, value in walk_dict(obj)} return {".".join(key): value for key, value in walk_dict(obj)}
def dot_to_object(config: "Config", section: str):
    """Convert dot notation of a "section" to a specific part of the Config.
    e.g. "training.optimizer" would return the Optimizer object.
    Throws an error if the section is not defined in this config.

    config (Config): The config.
    section (str): The dot notation of the section in the config.
    RETURNS: The object denoted by the section.
    RAISES (KeyError): If any component of the dotted path is missing.
    """
    component = config
    for item in section.split("."):
        try:
            component = component[item]
        except (KeyError, TypeError) as e:
            # TypeError covers indexing into a non-mapping leaf, e.g. asking
            # for "a.b.c" when "a.b" is already a scalar value.
            msg = f"The section '{section}' is not a valid section in the provided config."
            # Chain the original error so debugging shows which lookup failed.
            raise KeyError(msg) from e
    return component
def walk_dict( def walk_dict(
node: Dict[str, Any], parent: List[str] = [] node: Dict[str, Any], parent: List[str] = []
) -> Iterator[Tuple[List[str], Any]]: ) -> Iterator[Tuple[List[str], Any]]: