mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 17:36:30 +03:00
Merge pull request #5824 from svlandeg/fix/textcat-v3
This commit is contained in:
commit
b83ead5bf5
|
@ -20,6 +20,7 @@ import spacy
|
|||
from spacy import util
|
||||
from spacy.util import minibatch, compounding
|
||||
from spacy.gold import Example
|
||||
from thinc.api import Config
|
||||
|
||||
|
||||
@plac.annotations(
|
||||
|
@ -42,8 +43,8 @@ def main(config_path, output_dir=None, n_iter=20, n_texts=2000, init_tok2vec=Non
|
|||
output_dir.mkdir()
|
||||
|
||||
print(f"Loading nlp model from {config_path}")
|
||||
nlp_config = util.load_config(config_path, create_objects=False)["nlp"]
|
||||
nlp = util.load_model_from_config(nlp_config)
|
||||
nlp_config = Config().from_disk(config_path)
|
||||
nlp, _ = util.load_model_from_config(nlp_config, auto_fill=True)
|
||||
|
||||
# ensure the nlp object was defined with a textcat component
|
||||
if "textcat" not in nlp.pipe_names:
|
||||
|
|
|
@ -1,19 +1,14 @@
|
|||
[nlp]
|
||||
lang = "en"
|
||||
pipeline = ["textcat"]
|
||||
|
||||
[nlp.pipeline.textcat]
|
||||
[components]
|
||||
|
||||
[components.textcat]
|
||||
factory = "textcat"
|
||||
|
||||
[nlp.pipeline.textcat.model]
|
||||
@architectures = "spacy.TextCatCNN.v1"
|
||||
exclusive_classes = false
|
||||
|
||||
[nlp.pipeline.textcat.model.tok2vec]
|
||||
@architectures = "spacy.HashEmbedCNN.v1"
|
||||
pretrained_vectors = null
|
||||
width = 96
|
||||
depth = 4
|
||||
embed_size = 2000
|
||||
window_size = 1
|
||||
maxout_pieces = 3
|
||||
subword_features = true
|
||||
[components.textcat.model]
|
||||
@architectures = "spacy.TextCatBOW.v1"
|
||||
exclusive_classes = true
|
||||
ngram_size = 1
|
||||
no_output_layer = false
|
||||
|
|
|
@ -8,6 +8,7 @@ import typer
|
|||
from ._util import Arg, Opt, debug_cli, show_validation_error, parse_config_overrides
|
||||
from .. import util
|
||||
from ..lang.en import English
|
||||
from ..util import dot_to_object
|
||||
|
||||
|
||||
@debug_cli.command("model")
|
||||
|
@ -60,16 +61,7 @@ def debug_model_cli(
|
|||
msg.info(f"Fixing random seed: {seed}")
|
||||
fix_random_seed(seed)
|
||||
|
||||
component = config
|
||||
parts = section.split(".")
|
||||
for item in parts:
|
||||
try:
|
||||
component = component[item]
|
||||
except KeyError:
|
||||
msg.fail(
|
||||
f"The section '{section}' is not a valid section in the provided config.",
|
||||
exits=1,
|
||||
)
|
||||
component = dot_to_object(config, section)
|
||||
if hasattr(component, "model"):
|
||||
model = component.model
|
||||
else:
|
||||
|
|
|
@ -592,7 +592,7 @@ class Errors:
|
|||
"for the `nlp` pipeline with components {names}.")
|
||||
E993 = ("The config for 'nlp' needs to include a key 'lang' specifying "
|
||||
"the code of the language to initialize it with (for example "
|
||||
"'en' for English).\n\n{config}")
|
||||
"'en' for English) - this can't be 'None'.\n\n{config}")
|
||||
E996 = ("Could not parse {file}: {msg}")
|
||||
E997 = ("Tokenizer special cases are not allowed to modify the text. "
|
||||
"This would map '{chunk}' to '{orth}' given token attributes "
|
||||
|
|
|
@ -2,7 +2,13 @@ import pytest
|
|||
|
||||
from .util import get_random_doc
|
||||
|
||||
from spacy.util import minibatch_by_words
|
||||
from spacy import util
|
||||
from spacy.util import minibatch_by_words, dot_to_object
|
||||
from thinc.api import Config, Optimizer
|
||||
|
||||
from ..lang.en import English
|
||||
from ..lang.nl import Dutch
|
||||
from ..language import DEFAULT_CONFIG_PATH
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
@ -56,3 +62,49 @@ def test_util_minibatch_oversize(doc_sizes, expected_batches):
|
|||
minibatch_by_words(docs, size=batch_size, tolerance=tol, discard_oversize=False)
|
||||
)
|
||||
assert [len(batch) for batch in batches] == expected_batches
|
||||
|
||||
|
||||
def test_util_dot_section():
|
||||
cfg_string = """
|
||||
[nlp]
|
||||
lang = "en"
|
||||
pipeline = ["textcat"]
|
||||
load_vocab_data = false
|
||||
|
||||
[components]
|
||||
|
||||
[components.textcat]
|
||||
factory = "textcat"
|
||||
|
||||
[components.textcat.model]
|
||||
@architectures = "spacy.TextCatBOW.v1"
|
||||
exclusive_classes = true
|
||||
ngram_size = 1
|
||||
no_output_layer = false
|
||||
"""
|
||||
nlp_config = Config().from_str(cfg_string)
|
||||
en_nlp, en_config = util.load_model_from_config(nlp_config, auto_fill=True)
|
||||
|
||||
default_config = Config().from_disk(DEFAULT_CONFIG_PATH)
|
||||
default_config["nlp"]["lang"] = "nl"
|
||||
nl_nlp, nl_config = util.load_model_from_config(default_config, auto_fill=True)
|
||||
|
||||
# Test that creation went OK
|
||||
assert isinstance(en_nlp, English)
|
||||
assert isinstance(nl_nlp, Dutch)
|
||||
assert nl_nlp.pipe_names == []
|
||||
assert en_nlp.pipe_names == ["textcat"]
|
||||
assert en_nlp.get_pipe("textcat").model.attrs["multi_label"] == False # not exclusive_classes
|
||||
|
||||
# Test that default values got overwritten
|
||||
assert not en_config["nlp"]["load_vocab_data"]
|
||||
assert nl_config["nlp"]["load_vocab_data"] # default value True
|
||||
|
||||
# Test proper functioning of 'dot_to_object'
|
||||
with pytest.raises(KeyError):
|
||||
obj = dot_to_object(en_config, "nlp.pipeline.tagger")
|
||||
with pytest.raises(KeyError):
|
||||
obj = dot_to_object(en_config, "nlp.unknownattribute")
|
||||
assert not dot_to_object(en_config, "nlp.load_vocab_data")
|
||||
assert dot_to_object(nl_config, "nlp.load_vocab_data")
|
||||
assert isinstance(dot_to_object(nl_config, "training.optimizer"), Optimizer)
|
||||
|
|
|
@ -258,7 +258,7 @@ def load_model_from_config(
|
|||
if "nlp" not in config:
|
||||
raise ValueError(Errors.E985.format(config=config))
|
||||
nlp_config = config["nlp"]
|
||||
if "lang" not in nlp_config:
|
||||
if "lang" not in nlp_config or nlp_config["lang"] is None:
|
||||
raise ValueError(Errors.E993.format(config=nlp_config))
|
||||
# This will automatically handle all codes registered via the languages
|
||||
# registry, including custom subclasses provided via entry points
|
||||
|
@ -1107,6 +1107,26 @@ def dict_to_dot(obj: Dict[str, dict]) -> Dict[str, Any]:
|
|||
return {".".join(key): value for key, value in walk_dict(obj)}
|
||||
|
||||
|
||||
def dot_to_object(config: Config, section: str):
|
||||
"""Convert dot notation of a "section" to a specific part of the Config.
|
||||
e.g. "training.optimizer" would return the Optimizer object.
|
||||
Throws an error if the section is not defined in this config.
|
||||
|
||||
config (Config): The config.
|
||||
section (str): The dot notation of the section in the config.
|
||||
RETURNS: The object denoted by the section
|
||||
"""
|
||||
component = config
|
||||
parts = section.split(".")
|
||||
for item in parts:
|
||||
try:
|
||||
component = component[item]
|
||||
except (KeyError, TypeError) as e:
|
||||
msg = f"The section '{section}' is not a valid section in the provided config."
|
||||
raise KeyError(msg)
|
||||
return component
|
||||
|
||||
|
||||
def walk_dict(
|
||||
node: Dict[str, Any], parent: List[str] = []
|
||||
) -> Iterator[Tuple[List[str], Any]]:
|
||||
|
|
Loading…
Reference in New Issue
Block a user