Merge pull request #6168 from explosion/fix/default-corpus-values

This commit is contained in:
Ines Montani 2020-09-30 00:24:02 +02:00 committed by GitHub
commit fe3f111c37
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 15 additions and 11 deletions

View File

@ -4,8 +4,8 @@ can help generate the best possible configuration, given a user's requirements.
{%- set use_transformer = (transformer_data and hardware != "cpu") -%} {%- set use_transformer = (transformer_data and hardware != "cpu") -%}
{%- set transformer = transformer_data[optimize] if use_transformer else {} -%} {%- set transformer = transformer_data[optimize] if use_transformer else {} -%}
[paths] [paths]
train = "" train = null
dev = "" dev = null
[system] [system]
{% if use_transformer -%} {% if use_transformer -%}

View File

@ -1,6 +1,6 @@
[paths] [paths]
train = "" train = null
dev = "" dev = null
vectors = null vectors = null
vocab_data = null vocab_data = null
init_tok2vec = null init_tok2vec = null

View File

@ -477,6 +477,8 @@ class Errors:
E201 = ("Span index out of range.") E201 = ("Span index out of range.")
# TODO: fix numbering after merging develop into master # TODO: fix numbering after merging develop into master
E913 = ("Corpus path can't be None. Maybe you forgot to define it in your "
"config.cfg or override it on the CLI?")
E914 = ("Executing {name} callback failed. Expected the function to " E914 = ("Executing {name} callback failed. Expected the function to "
"return the nlp object but got: {value}. Maybe you forgot to return " "return the nlp object but got: {value}. Maybe you forgot to return "
"the modified object in your function?") "the modified object in your function?")

View File

@ -14,8 +14,8 @@ from ..util import make_tempdir
nlp_config_string = """ nlp_config_string = """
[paths] [paths]
train = "" train = null
dev = "" dev = null
[corpora] [corpora]
@ -309,7 +309,7 @@ def test_config_interpolation():
config = Config().from_str(nlp_config_string, interpolate=False) config = Config().from_str(nlp_config_string, interpolate=False)
assert config["corpora"]["train"]["path"] == "${paths.train}" assert config["corpora"]["train"]["path"] == "${paths.train}"
interpolated = config.interpolate() interpolated = config.interpolate()
assert interpolated["corpora"]["train"]["path"] == "" assert interpolated["corpora"]["train"]["path"] is None
nlp = English.from_config(config) nlp = English.from_config(config)
assert nlp.config["corpora"]["train"]["path"] == "${paths.train}" assert nlp.config["corpora"]["train"]["path"] == "${paths.train}"
# Ensure that variables are preserved in nlp config # Ensure that variables are preserved in nlp config
@ -317,10 +317,10 @@ def test_config_interpolation():
assert config["components"]["tagger"]["model"]["tok2vec"]["width"] == width assert config["components"]["tagger"]["model"]["tok2vec"]["width"] == width
assert nlp.config["components"]["tagger"]["model"]["tok2vec"]["width"] == width assert nlp.config["components"]["tagger"]["model"]["tok2vec"]["width"] == width
interpolated2 = nlp.config.interpolate() interpolated2 = nlp.config.interpolate()
assert interpolated2["corpora"]["train"]["path"] == "" assert interpolated2["corpora"]["train"]["path"] is None
assert interpolated2["components"]["tagger"]["model"]["tok2vec"]["width"] == 342 assert interpolated2["components"]["tagger"]["model"]["tok2vec"]["width"] == 342
nlp2 = English.from_config(interpolated) nlp2 = English.from_config(interpolated)
assert nlp2.config["corpora"]["train"]["path"] == "" assert nlp2.config["corpora"]["train"]["path"] is None
assert nlp2.config["components"]["tagger"]["model"]["tok2vec"]["width"] == 342 assert nlp2.config["components"]["tagger"]["model"]["tok2vec"]["width"] == 342

View File

@ -7,7 +7,7 @@ import srsly
from .. import util from .. import util
from .augment import dont_augment from .augment import dont_augment
from .example import Example from .example import Example
from ..errors import Warnings from ..errors import Warnings, Errors
from ..tokens import DocBin, Doc from ..tokens import DocBin, Doc
from ..vocab import Vocab from ..vocab import Vocab
@ -20,12 +20,14 @@ FILE_TYPE = ".spacy"
@util.registry.readers("spacy.Corpus.v1") @util.registry.readers("spacy.Corpus.v1")
def create_docbin_reader( def create_docbin_reader(
path: Path, path: Optional[Path],
gold_preproc: bool, gold_preproc: bool,
max_length: int = 0, max_length: int = 0,
limit: int = 0, limit: int = 0,
augmenter: Optional[Callable] = None, augmenter: Optional[Callable] = None,
) -> Callable[["Language"], Iterable[Example]]: ) -> Callable[["Language"], Iterable[Example]]:
if path is None:
raise ValueError(Errors.E913)
util.logger.debug(f"Loading corpus from path: {path}") util.logger.debug(f"Loading corpus from path: {path}")
return Corpus( return Corpus(
path, path,