mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-26 09:56:28 +03:00
Merge pull request #5824 from svlandeg/fix/textcat-v3
This commit is contained in:
commit
b83ead5bf5
|
@ -20,6 +20,7 @@ import spacy
|
||||||
from spacy import util
|
from spacy import util
|
||||||
from spacy.util import minibatch, compounding
|
from spacy.util import minibatch, compounding
|
||||||
from spacy.gold import Example
|
from spacy.gold import Example
|
||||||
|
from thinc.api import Config
|
||||||
|
|
||||||
|
|
||||||
@plac.annotations(
|
@plac.annotations(
|
||||||
|
@ -42,8 +43,8 @@ def main(config_path, output_dir=None, n_iter=20, n_texts=2000, init_tok2vec=Non
|
||||||
output_dir.mkdir()
|
output_dir.mkdir()
|
||||||
|
|
||||||
print(f"Loading nlp model from {config_path}")
|
print(f"Loading nlp model from {config_path}")
|
||||||
nlp_config = util.load_config(config_path, create_objects=False)["nlp"]
|
nlp_config = Config().from_disk(config_path)
|
||||||
nlp = util.load_model_from_config(nlp_config)
|
nlp, _ = util.load_model_from_config(nlp_config, auto_fill=True)
|
||||||
|
|
||||||
# ensure the nlp object was defined with a textcat component
|
# ensure the nlp object was defined with a textcat component
|
||||||
if "textcat" not in nlp.pipe_names:
|
if "textcat" not in nlp.pipe_names:
|
||||||
|
|
|
@ -1,19 +1,14 @@
|
||||||
[nlp]
|
[nlp]
|
||||||
lang = "en"
|
lang = "en"
|
||||||
|
pipeline = ["textcat"]
|
||||||
|
|
||||||
[nlp.pipeline.textcat]
|
[components]
|
||||||
|
|
||||||
|
[components.textcat]
|
||||||
factory = "textcat"
|
factory = "textcat"
|
||||||
|
|
||||||
[nlp.pipeline.textcat.model]
|
[components.textcat.model]
|
||||||
@architectures = "spacy.TextCatCNN.v1"
|
@architectures = "spacy.TextCatBOW.v1"
|
||||||
exclusive_classes = false
|
exclusive_classes = true
|
||||||
|
ngram_size = 1
|
||||||
[nlp.pipeline.textcat.model.tok2vec]
|
no_output_layer = false
|
||||||
@architectures = "spacy.HashEmbedCNN.v1"
|
|
||||||
pretrained_vectors = null
|
|
||||||
width = 96
|
|
||||||
depth = 4
|
|
||||||
embed_size = 2000
|
|
||||||
window_size = 1
|
|
||||||
maxout_pieces = 3
|
|
||||||
subword_features = true
|
|
||||||
|
|
|
@ -8,6 +8,7 @@ import typer
|
||||||
from ._util import Arg, Opt, debug_cli, show_validation_error, parse_config_overrides
|
from ._util import Arg, Opt, debug_cli, show_validation_error, parse_config_overrides
|
||||||
from .. import util
|
from .. import util
|
||||||
from ..lang.en import English
|
from ..lang.en import English
|
||||||
|
from ..util import dot_to_object
|
||||||
|
|
||||||
|
|
||||||
@debug_cli.command("model")
|
@debug_cli.command("model")
|
||||||
|
@ -60,16 +61,7 @@ def debug_model_cli(
|
||||||
msg.info(f"Fixing random seed: {seed}")
|
msg.info(f"Fixing random seed: {seed}")
|
||||||
fix_random_seed(seed)
|
fix_random_seed(seed)
|
||||||
|
|
||||||
component = config
|
component = dot_to_object(config, section)
|
||||||
parts = section.split(".")
|
|
||||||
for item in parts:
|
|
||||||
try:
|
|
||||||
component = component[item]
|
|
||||||
except KeyError:
|
|
||||||
msg.fail(
|
|
||||||
f"The section '{section}' is not a valid section in the provided config.",
|
|
||||||
exits=1,
|
|
||||||
)
|
|
||||||
if hasattr(component, "model"):
|
if hasattr(component, "model"):
|
||||||
model = component.model
|
model = component.model
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -592,7 +592,7 @@ class Errors:
|
||||||
"for the `nlp` pipeline with components {names}.")
|
"for the `nlp` pipeline with components {names}.")
|
||||||
E993 = ("The config for 'nlp' needs to include a key 'lang' specifying "
|
E993 = ("The config for 'nlp' needs to include a key 'lang' specifying "
|
||||||
"the code of the language to initialize it with (for example "
|
"the code of the language to initialize it with (for example "
|
||||||
"'en' for English).\n\n{config}")
|
"'en' for English) - this can't be 'None'.\n\n{config}")
|
||||||
E996 = ("Could not parse {file}: {msg}")
|
E996 = ("Could not parse {file}: {msg}")
|
||||||
E997 = ("Tokenizer special cases are not allowed to modify the text. "
|
E997 = ("Tokenizer special cases are not allowed to modify the text. "
|
||||||
"This would map '{chunk}' to '{orth}' given token attributes "
|
"This would map '{chunk}' to '{orth}' given token attributes "
|
||||||
|
|
|
@ -2,7 +2,13 @@ import pytest
|
||||||
|
|
||||||
from .util import get_random_doc
|
from .util import get_random_doc
|
||||||
|
|
||||||
from spacy.util import minibatch_by_words
|
from spacy import util
|
||||||
|
from spacy.util import minibatch_by_words, dot_to_object
|
||||||
|
from thinc.api import Config, Optimizer
|
||||||
|
|
||||||
|
from ..lang.en import English
|
||||||
|
from ..lang.nl import Dutch
|
||||||
|
from ..language import DEFAULT_CONFIG_PATH
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
@ -56,3 +62,49 @@ def test_util_minibatch_oversize(doc_sizes, expected_batches):
|
||||||
minibatch_by_words(docs, size=batch_size, tolerance=tol, discard_oversize=False)
|
minibatch_by_words(docs, size=batch_size, tolerance=tol, discard_oversize=False)
|
||||||
)
|
)
|
||||||
assert [len(batch) for batch in batches] == expected_batches
|
assert [len(batch) for batch in batches] == expected_batches
|
||||||
|
|
||||||
|
|
||||||
|
def test_util_dot_section():
|
||||||
|
cfg_string = """
|
||||||
|
[nlp]
|
||||||
|
lang = "en"
|
||||||
|
pipeline = ["textcat"]
|
||||||
|
load_vocab_data = false
|
||||||
|
|
||||||
|
[components]
|
||||||
|
|
||||||
|
[components.textcat]
|
||||||
|
factory = "textcat"
|
||||||
|
|
||||||
|
[components.textcat.model]
|
||||||
|
@architectures = "spacy.TextCatBOW.v1"
|
||||||
|
exclusive_classes = true
|
||||||
|
ngram_size = 1
|
||||||
|
no_output_layer = false
|
||||||
|
"""
|
||||||
|
nlp_config = Config().from_str(cfg_string)
|
||||||
|
en_nlp, en_config = util.load_model_from_config(nlp_config, auto_fill=True)
|
||||||
|
|
||||||
|
default_config = Config().from_disk(DEFAULT_CONFIG_PATH)
|
||||||
|
default_config["nlp"]["lang"] = "nl"
|
||||||
|
nl_nlp, nl_config = util.load_model_from_config(default_config, auto_fill=True)
|
||||||
|
|
||||||
|
# Test that creation went OK
|
||||||
|
assert isinstance(en_nlp, English)
|
||||||
|
assert isinstance(nl_nlp, Dutch)
|
||||||
|
assert nl_nlp.pipe_names == []
|
||||||
|
assert en_nlp.pipe_names == ["textcat"]
|
||||||
|
assert en_nlp.get_pipe("textcat").model.attrs["multi_label"] == False # not exclusive_classes
|
||||||
|
|
||||||
|
# Test that default values got overwritten
|
||||||
|
assert not en_config["nlp"]["load_vocab_data"]
|
||||||
|
assert nl_config["nlp"]["load_vocab_data"] # default value True
|
||||||
|
|
||||||
|
# Test proper functioning of 'dot_to_object'
|
||||||
|
with pytest.raises(KeyError):
|
||||||
|
obj = dot_to_object(en_config, "nlp.pipeline.tagger")
|
||||||
|
with pytest.raises(KeyError):
|
||||||
|
obj = dot_to_object(en_config, "nlp.unknownattribute")
|
||||||
|
assert not dot_to_object(en_config, "nlp.load_vocab_data")
|
||||||
|
assert dot_to_object(nl_config, "nlp.load_vocab_data")
|
||||||
|
assert isinstance(dot_to_object(nl_config, "training.optimizer"), Optimizer)
|
||||||
|
|
|
@ -258,7 +258,7 @@ def load_model_from_config(
|
||||||
if "nlp" not in config:
|
if "nlp" not in config:
|
||||||
raise ValueError(Errors.E985.format(config=config))
|
raise ValueError(Errors.E985.format(config=config))
|
||||||
nlp_config = config["nlp"]
|
nlp_config = config["nlp"]
|
||||||
if "lang" not in nlp_config:
|
if "lang" not in nlp_config or nlp_config["lang"] is None:
|
||||||
raise ValueError(Errors.E993.format(config=nlp_config))
|
raise ValueError(Errors.E993.format(config=nlp_config))
|
||||||
# This will automatically handle all codes registered via the languages
|
# This will automatically handle all codes registered via the languages
|
||||||
# registry, including custom subclasses provided via entry points
|
# registry, including custom subclasses provided via entry points
|
||||||
|
@ -1107,6 +1107,26 @@ def dict_to_dot(obj: Dict[str, dict]) -> Dict[str, Any]:
|
||||||
return {".".join(key): value for key, value in walk_dict(obj)}
|
return {".".join(key): value for key, value in walk_dict(obj)}
|
||||||
|
|
||||||
|
|
||||||
|
def dot_to_object(config: Config, section: str):
|
||||||
|
"""Convert dot notation of a "section" to a specific part of the Config.
|
||||||
|
e.g. "training.optimizer" would return the Optimizer object.
|
||||||
|
Throws an error if the section is not defined in this config.
|
||||||
|
|
||||||
|
config (Config): The config.
|
||||||
|
section (str): The dot notation of the section in the config.
|
||||||
|
RETURNS: The object denoted by the section
|
||||||
|
"""
|
||||||
|
component = config
|
||||||
|
parts = section.split(".")
|
||||||
|
for item in parts:
|
||||||
|
try:
|
||||||
|
component = component[item]
|
||||||
|
except (KeyError, TypeError) as e:
|
||||||
|
msg = f"The section '{section}' is not a valid section in the provided config."
|
||||||
|
raise KeyError(msg)
|
||||||
|
return component
|
||||||
|
|
||||||
|
|
||||||
def walk_dict(
|
def walk_dict(
|
||||||
node: Dict[str, Any], parent: List[str] = []
|
node: Dict[str, Any], parent: List[str] = []
|
||||||
) -> Iterator[Tuple[List[str], Any]]:
|
) -> Iterator[Tuple[List[str], Any]]:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user