Improve v3 pretrain command (#6040)

* Starts to run

* Update pretrain script

* Update corpus

* Update pretrain schema

* Remove outdated test

* Make JsonlTexts produce Example objects.
Matthew Honnibal authored 2020-09-13 14:05:05 +02:00, committed by GitHub
parent 1316071086
commit 54c40223a1
4 changed files with 158 additions and 187 deletions

View File

@@ -1,10 +1,10 @@
from typing import Optional, Dict, Any
import random
from typing import Optional
import numpy
import time
import re
from collections import Counter
from pathlib import Path
from thinc.api import Config
from thinc.api import use_pytorch_for_gpu_memory, require_gpu
from thinc.api import set_dropout_rate, to_categorical, fix_random_seed
from thinc.api import CosineDistance, L2Distance
@@ -15,11 +15,10 @@ import typer
from ._util import app, Arg, Opt, parse_config_overrides, show_validation_error
from ._util import import_code
from ..errors import Errors
from ..ml.models.multi_task import build_cloze_multi_task_model
from ..ml.models.multi_task import build_cloze_characters_multi_task_model
from ..tokens import Doc
from ..attrs import ID, HEAD
from ..attrs import ID
from .. import util
@@ -30,9 +29,8 @@ from .. import util
def pretrain_cli(
# fmt: off
ctx: typer.Context, # This is only used to read additional arguments
texts_loc: Path = Arg(..., help="Path to JSONL file with raw texts to learn from, with text provided as the key 'text' or tokens as the key 'tokens'", exists=True),
output_dir: Path = Arg(..., help="Directory to write weights to on each epoch"),
config_path: Path = Arg(..., help="Path to config file", exists=True, dir_okay=False),
output_dir: Path = Arg(..., help="Directory to write weights to on each epoch"),
code_path: Optional[Path] = Opt(None, "--code-path", "-c", help="Path to Python file with additional code (registered functions) to be imported"),
resume_path: Optional[Path] = Opt(None, "--resume-path", "-r", help="Path to pretrained weights from which to resume pretraining"),
epoch_resume: Optional[int] = Opt(None, "--epoch-resume", "-er", help="The epoch to resume counting from when using --resume-path. Prevents unintended overwriting of existing weight files."),
@@ -60,13 +58,35 @@ def pretrain_cli(
DOCS: https://nightly.spacy.io/api/cli#pretrain
"""
overrides = parse_config_overrides(ctx.args)
config_overrides = parse_config_overrides(ctx.args)
import_code(code_path)
verify_cli_args(config_path, output_dir, resume_path, epoch_resume)
if use_gpu >= 0:
msg.info("Using GPU")
require_gpu(use_gpu)
else:
msg.info("Using CPU")
msg.info(f"Loading config from: {config_path}")
with show_validation_error(config_path):
config = util.load_config(
config_path,
overrides=config_overrides,
interpolate=True
)
if not config.get("pretraining"):
# TODO: What's the solution here? How do we handle optional blocks?
msg.fail("The [pretraining] block in your config is empty", exits=1)
if not output_dir.exists():
output_dir.mkdir()
msg.good(f"Created output directory: {output_dir}")
config.to_disk(output_dir / "config.cfg")
msg.good("Saved config file in the output directory")
pretrain(
texts_loc,
config,
output_dir,
config_path,
config_overrides=overrides,
resume_path=resume_path,
epoch_resume=epoch_resume,
use_gpu=use_gpu,
@@ -74,52 +94,22 @@ def pretrain_cli(
def pretrain(
texts_loc: Path,
config: Config,
output_dir: Path,
config_path: Path,
config_overrides: Dict[str, Any] = {},
resume_path: Optional[Path] = None,
epoch_resume: Optional[int] = None,
use_gpu: int = -1,
use_gpu: int=-1
):
verify_cli_args(texts_loc, output_dir, config_path, resume_path, epoch_resume)
if use_gpu >= 0:
msg.info("Using GPU")
require_gpu(use_gpu)
else:
msg.info("Using CPU")
msg.info(f"Loading config from: {config_path}")
with show_validation_error(config_path):
config = util.load_config(config_path, overrides=config_overrides)
nlp, config = util.load_model_from_config(config)
pretrain_config = config["pretraining"]
if not pretrain_config:
# TODO: What's the solution here? How do we handle optional blocks?
msg.fail("The [pretraining] block in your config is empty", exits=1)
if not output_dir.exists():
output_dir.mkdir()
msg.good(f"Created output directory: {output_dir}")
seed = pretrain_config["seed"]
if seed is not None:
fix_random_seed(seed)
if use_gpu >= 0 and pretrain_config["use_pytorch_for_gpu_memory"]:
if config["system"].get("seed") is not None:
fix_random_seed(config["system"]["seed"])
if use_gpu >= 0 and config["system"].get("use_pytorch_for_gpu_memory"):
use_pytorch_for_gpu_memory()
config.to_disk(output_dir / "config.cfg")
msg.good("Saved config file in the output directory")
if texts_loc != "-": # reading from a file
with msg.loading("Loading input texts..."):
texts = list(srsly.read_jsonl(texts_loc))
random.shuffle(texts)
else: # reading from stdin
msg.info("Reading input text from stdin...")
texts = srsly.read_jsonl("-")
tok2vec_path = pretrain_config["tok2vec_model"]
tok2vec = config
for subpath in tok2vec_path.split("."):
tok2vec = tok2vec.get(subpath)
model = create_pretraining_model(nlp, tok2vec, pretrain_config)
optimizer = pretrain_config["optimizer"]
nlp, config = util.load_model_from_config(config)
P_cfg = config["pretraining"]
corpus = P_cfg["corpus"]
batcher = P_cfg["batcher"]
model = create_pretraining_model(nlp, config["pretraining"])
optimizer = config["pretraining"]["optimizer"]
# Load in pretrained weights to resume from
if resume_path is not None:
@@ -147,38 +137,35 @@ def pretrain(
with (output_dir / "log.jsonl").open("a") as file_:
file_.write(srsly.json_dumps(log) + "\n")
skip_counter = 0
objective = create_objective(pretrain_config["objective"])
for epoch in range(epoch_resume, pretrain_config["max_epochs"]):
batches = util.minibatch_by_words(texts, size=pretrain_config["batch_size"])
for batch_id, batch in enumerate(batches):
docs, count = make_docs(
nlp,
batch,
max_length=pretrain_config["max_length"],
min_length=pretrain_config["min_length"],
)
skip_counter += count
objective = create_objective(P_cfg["objective"])
# TODO: I think we probably want this to look more like the
# 'create_train_batches' function?
for epoch in range(epoch_resume, P_cfg["max_epochs"]):
for batch_id, batch in enumerate(batcher(corpus(nlp))):
docs = ensure_docs(batch)
loss = make_update(model, docs, optimizer, objective)
progress = tracker.update(epoch, loss, docs)
if progress:
msg.row(progress, **row_settings)
if texts_loc == "-" and tracker.words_per_epoch[epoch] >= 10 ** 7:
break
if pretrain_config["n_save_every"] and (
batch_id % pretrain_config["n_save_every"] == 0
if P_cfg["n_save_every"] and (
batch_id % P_cfg["n_save_every"] == 0
):
_save_model(epoch, is_temp=True)
_save_model(epoch)
tracker.epoch_loss = 0.0
if texts_loc != "-":
# Reshuffle the texts if texts were loaded from a file
random.shuffle(texts)
if skip_counter > 0:
msg.warn(f"Skipped {skip_counter} empty values")
msg.good("Successfully finished pretrain")
def ensure_docs(examples_or_docs):
docs = []
for eg_or_doc in examples_or_docs:
if isinstance(eg_or_doc, Doc):
docs.append(eg_or_doc)
else:
docs.append(eg_or_doc.reference)
return docs
def _resume_model(model, resume_path, epoch_resume):
msg.info(f"Resume training tok2vec from: {resume_path}")
with resume_path.open("rb") as file_:
@@ -211,36 +198,6 @@ def make_update(model, docs, optimizer, objective_func):
return float(loss)
def make_docs(nlp, batch, min_length, max_length):
docs = []
skip_count = 0
for record in batch:
if not isinstance(record, dict):
raise TypeError(Errors.E137.format(type=type(record), line=record))
if "tokens" in record:
words = record["tokens"]
if not words:
skip_count += 1
continue
doc = Doc(nlp.vocab, words=words)
elif "text" in record:
text = record["text"]
if not text:
skip_count += 1
continue
doc = nlp.make_doc(text)
else:
raise ValueError(Errors.E138.format(text=record))
if "heads" in record:
heads = record["heads"]
heads = numpy.asarray(heads, dtype="uint64")
heads = heads.reshape((len(doc), 1))
doc = doc.from_array([HEAD], heads)
if min_length <= len(doc) < max_length:
docs.append(doc)
return docs, skip_count
def create_objective(config):
"""Create the objective for pretraining.
@@ -296,7 +253,7 @@ def get_characters_loss(ops, docs, prediction, nr_char):
return loss, d_target
def create_pretraining_model(nlp, tok2vec, pretrain_config):
def create_pretraining_model(nlp, pretrain_config):
"""Define a network for the pretraining. We simply add an output layer onto
the tok2vec input model. The tok2vec input model needs to be a model that
takes a batch of Doc objects (as a list), and returns a list of arrays.
@@ -304,6 +261,12 @@ def create_pretraining_model(nlp, tok2vec, pretrain_config):
The actual tok2vec layer is stored as a reference, and only this bit will be
serialized to file and read back in when calling the 'train' command.
"""
component = nlp.get_pipe(pretrain_config["component"])
if pretrain_config.get("layer"):
tok2vec = component.model.get_ref(pretrain_config["layer"])
else:
tok2vec = component.model
# TODO
maxout_pieces = 3
hidden_size = 300
@@ -372,7 +335,7 @@ def _smart_round(figure, width=10, max_decimal=4):
return format_str % figure
def verify_cli_args(texts_loc, output_dir, config_path, resume_path, epoch_resume):
def verify_cli_args(config_path, output_dir, resume_path, epoch_resume):
if not config_path or not config_path.exists():
msg.fail("Config file not found", config_path, exits=1)
if output_dir.exists() and [p for p in output_dir.iterdir()]:
@@ -388,16 +351,6 @@ def verify_cli_args(texts_loc, output_dir, config_path, resume_path, epoch_resum
"It is better to use an empty directory or refer to a new output path, "
"then the new directory will be created for you.",
)
if texts_loc != "-": # reading from a file
texts_loc = Path(texts_loc)
if not texts_loc.exists():
msg.fail("Input text file doesn't exist", texts_loc, exits=1)
for text in srsly.read_jsonl(texts_loc):
break
else:
msg.fail("Input file is empty", texts_loc, exits=1)
if resume_path is not None:
model_name = re.search(r"model\d+\.bin", str(resume_path))
if not model_name and not epoch_resume:
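
To make the reshaped entry points concrete, here is a hedged usage sketch (not part of this diff): config.cfg and pretrain_output are placeholder paths, and the call simply mirrors what the new pretrain_cli does after creating the output directory and saving the config.

from pathlib import Path

from spacy import util
from spacy.cli.pretrain import pretrain

# Placeholder paths; the config is assumed to contain a filled-in [pretraining]
# block, and the output directory is assumed to already exist (pretrain_cli
# creates it and writes config.cfg there before handing off to pretrain()).
config = util.load_config(Path("config.cfg"), interpolate=True)
pretrain(config, Path("pretrain_output"), use_gpu=-1)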

View File

@@ -246,15 +246,14 @@ class ConfigSchemaPretrainEmpty(BaseModel):
class ConfigSchemaPretrain(BaseModel):
# fmt: off
max_epochs: StrictInt = Field(..., title="Maximum number of epochs to train for")
min_length: StrictInt = Field(..., title="Minimum length of examples")
max_length: StrictInt = Field(..., title="Maximum length of examples")
dropout: StrictFloat = Field(..., title="Dropout rate")
n_save_every: Optional[StrictInt] = Field(..., title="Saving frequency")
batch_size: Union[Sequence[int], int] = Field(..., title="The batch size or batch size schedule")
seed: Optional[StrictInt] = Field(..., title="Random seed")
use_pytorch_for_gpu_memory: StrictBool = Field(..., title="Allocate memory via PyTorch")
tok2vec_model: StrictStr = Field(..., title="tok2vec model in config, e.g. components.tok2vec.model")
optimizer: Optimizer = Field(..., title="The optimizer to use")
corpus: Reader = Field(..., title="Reader for the training data")
batcher: Batcher = Field(..., title="Batcher for the training data")
component: str = Field(..., title="Component to find the layer to pretrain")
layer: str = Field(..., title="Layer to pretrain. Whole model if empty.")
# TODO: use a more detailed schema for this?
objective: Dict[str, Any] = Field(..., title="Pretraining objective")
# fmt: on
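
For reference, a hedged sketch of a [pretraining] block that would satisfy the updated schema. Only the field names come from ConfigSchemaPretrain above; the registry names and values are illustrative assumptions. Note that seed and use_pytorch_for_gpu_memory no longer appear here, since the reworked pretrain() reads them from config["system"].

from thinc.api import Config

# Illustrative values only; this just parses the block, it does not resolve registries.
pretrain_section = Config().from_str("""
[pretraining]
max_epochs = 1000
dropout = 0.2
n_save_every = null
component = "tok2vec"
layer = ""

[pretraining.objective]
type = "characters"
n_characters = 4

[pretraining.optimizer]
@optimizers = "Adam.v1"

[pretraining.corpus]
@readers = "spacy.JsonlReader.v1"
path = "texts.jsonl"
min_length = 5
max_length = 500

[pretraining.batcher]
@batchers = "spacy.batch_by_words.v1"
size = 3000
discard_oversize = false
tolerance = 0.2
""")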

View File

@@ -5,7 +5,6 @@ from spacy.training import docs_to_json, biluo_tags_from_offsets
from spacy.training.converters import iob2docs, conll_ner2docs, conllu2docs
from spacy.lang.en import English
from spacy.schemas import ProjectConfigSchema, RecommendationSchema, validate
from spacy.cli.pretrain import make_docs
from spacy.cli.init_config import init_config, RECOMMENDATIONS
from spacy.cli._util import validate_project_commands, parse_config_overrides
from spacy.cli._util import load_project_config, substitute_project_variables
@@ -231,48 +230,6 @@ def test_cli_converters_conll_ner2json():
assert ent.text in ["New York City", "London"]
def test_pretrain_make_docs():
nlp = English()
valid_jsonl_text = {"text": "Some text"}
docs, skip_count = make_docs(nlp, [valid_jsonl_text], 1, 10)
assert len(docs) == 1
assert skip_count == 0
valid_jsonl_tokens = {"tokens": ["Some", "tokens"]}
docs, skip_count = make_docs(nlp, [valid_jsonl_tokens], 1, 10)
assert len(docs) == 1
assert skip_count == 0
invalid_jsonl_type = 0
with pytest.raises(TypeError):
make_docs(nlp, [invalid_jsonl_type], 1, 100)
invalid_jsonl_key = {"invalid": "Does not matter"}
with pytest.raises(ValueError):
make_docs(nlp, [invalid_jsonl_key], 1, 100)
empty_jsonl_text = {"text": ""}
docs, skip_count = make_docs(nlp, [empty_jsonl_text], 1, 10)
assert len(docs) == 0
assert skip_count == 1
empty_jsonl_tokens = {"tokens": []}
docs, skip_count = make_docs(nlp, [empty_jsonl_tokens], 1, 10)
assert len(docs) == 0
assert skip_count == 1
too_short_jsonl = {"text": "This text is not long enough"}
docs, skip_count = make_docs(nlp, [too_short_jsonl], 10, 15)
assert len(docs) == 0
assert skip_count == 0
too_long_jsonl = {"text": "This text contains way too much tokens for this test"}
docs, skip_count = make_docs(nlp, [too_long_jsonl], 1, 5)
assert len(docs) == 0
assert skip_count == 0
def test_project_config_validation_full():
config = {
"vars": {"some_var": 20},

View File

@@ -1,6 +1,7 @@
import warnings
from typing import Union, List, Iterable, Iterator, TYPE_CHECKING, Callable
from pathlib import Path
import srsly
from .. import util
from .example import Example
@@ -21,6 +22,36 @@ def create_docbin_reader(
) -> Callable[["Language"], Iterable[Example]]:
return Corpus(path, gold_preproc=gold_preproc, max_length=max_length, limit=limit)
@util.registry.readers("spacy.JsonlReader.v1")
def create_jsonl_reader(
path: Path, min_length: int=0, max_length: int = 0, limit: int = 0
) -> Callable[["Language"], Iterable[Doc]]:
return JsonlTexts(path, min_length=min_length, max_length=max_length, limit=limit)
def walk_corpus(path: Union[str, Path], file_type) -> List[Path]:
path = util.ensure_path(path)
if not path.is_dir() and path.parts[-1].endswith(file_type):
return [path]
orig_path = path
paths = [path]
locs = []
seen = set()
for path in paths:
if str(path) in seen:
continue
seen.add(str(path))
if path.parts and path.parts[-1].startswith("."):
continue
elif path.is_dir():
paths.extend(path.iterdir())
elif path.parts[-1].endswith(file_type):
locs.append(path)
if len(locs) == 0:
warnings.warn(Warnings.W090.format(path=orig_path))
return locs
class Corpus:
"""Iterate Example objects from a file or directory of DocBin (.spacy)
@@ -54,29 +85,6 @@ class Corpus:
self.max_length = max_length
self.limit = limit
@staticmethod
def walk_corpus(path: Union[str, Path]) -> List[Path]:
path = util.ensure_path(path)
if not path.is_dir() and path.parts[-1].endswith(FILE_TYPE):
return [path]
orig_path = path
paths = [path]
locs = []
seen = set()
for path in paths:
if str(path) in seen:
continue
seen.add(str(path))
if path.parts and path.parts[-1].startswith("."):
continue
elif path.is_dir():
paths.extend(path.iterdir())
elif path.parts[-1].endswith(FILE_TYPE):
locs.append(path)
if len(locs) == 0:
warnings.warn(Warnings.W090.format(path=orig_path))
return locs
def __call__(self, nlp: "Language") -> Iterator[Example]:
"""Yield examples from the data.
@@ -85,7 +93,7 @@ class Corpus:
DOCS: https://nightly.spacy.io/api/corpus#call
"""
ref_docs = self.read_docbin(nlp.vocab, self.walk_corpus(self.path))
ref_docs = self.read_docbin(nlp.vocab, walk_corpus(self.path, FILE_TYPE))
if self.gold_preproc:
examples = self.make_examples_gold_preproc(nlp, ref_docs)
else:
@@ -151,3 +159,57 @@ class Corpus:
i += 1
if self.limit >= 1 and i >= self.limit:
break
class JsonlTexts:
"""Iterate Doc objects from a file or directory of jsonl
formatted raw text files.
path (Path): The directory or filename to read from.
min_length (int): Minimum document length (in tokens). Shorter documents
will be skipped. Defaults to 0, which indicates no limit.
max_length (int): Maximum document length (in tokens). Longer documents will
be skipped. Defaults to 0, which indicates no limit.
limit (int): Limit corpus to a subset of examples, e.g. for debugging.
Defaults to 0, which indicates no limit.
DOCS: https://nightly.spacy.io/api/corpus
"""
file_type = "jsonl"
def __init__(
self,
path: Union[str, Path],
*,
limit: int = 0,
min_length: int = 0,
max_length: int = 0,
) -> None:
self.path = util.ensure_path(path)
self.min_length = min_length
self.max_length = max_length
self.limit = limit
def __call__(self, nlp: "Language") -> Iterator[Example]:
"""Yield examples from the data.
nlp (Language): The current nlp object.
YIELDS (Doc): The docs.
DOCS: https://nightly.spacy.io/api/corpus#call
"""
for loc in walk_corpus(self.path, "jsonl"):
records = srsly.read_jsonl(loc)
for record in records:
doc = nlp.make_doc(record["text"])
if self.min_length >= 1 and len(doc) < self.min_length:
continue
elif self.max_length >= 1 and len(doc) >= self.max_length:
continue
else:
words = [w.text for w in doc]
spaces = [bool(w.whitespace_) for w in doc]
# We don't *need* an example here, but it seems nice to
# make it match the Corpus signature.
yield Example(doc, Doc(nlp.vocab, words=words, spaces=spaces))
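
As a quick check of the new reader, a hedged sketch of iterating JsonlTexts directly; the spacy.training.corpus import path and the texts.jsonl file are assumptions for illustration.

from spacy.lang.en import English
from spacy.training.corpus import JsonlTexts  # assumed module path for the file above

nlp = English()
# texts.jsonl is a placeholder: one JSON object per line, e.g. {"text": "Some text"}
reader = JsonlTexts("texts.jsonl", min_length=1, max_length=500)
for example in reader(nlp):
    # Each item is an Example wrapping two Doc objects, matching ensure_docs()
    # in the pretrain changes earlier in this diff.
    print(example.reference.text)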