mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-24 17:06:29 +03:00
Create corpus iterator and batcher from registry during training (#5865)
* Move batchers into their own module (and registry) * Update CLI * Update Corpus and batcher * Update tests * Update one config * Merge 'evaluation' block back under [training] * Import batchers in gold __init__ * Fix batchers * Update config * Update schema * Update util * Don't assume train and dev are actually paths * Update onto-joint config * Fix missing import * Format * Format * Update spacy/gold/corpus.py Co-authored-by: Ines Montani <ines@ines.io> * Fix name * Update default config * Fix get_length option in batchers * Update test * Add comment * Pass path into Corpus * Update docstring * Update schema and configs * Update config * Fix test * Fix paths * Fix print * Fix create_train_batches * [training.read_train] -> [training.train_corpus] * Update onto-joint config Co-authored-by: Ines Montani <ines@ines.io>
This commit is contained in:
parent
82347110f5
commit
ecb3c4e8f4
|
@ -1,37 +1,45 @@
|
|||
# Training hyper-parameters and additional features.
|
||||
[training]
|
||||
# Whether to train on sequences with 'gold standard' sentence boundaries
|
||||
# and tokens. If you set this to true, take care to ensure your run-time
|
||||
# data is passed in sentence-by-sentence via some prior preprocessing.
|
||||
gold_preproc = false
|
||||
# Limitations on training document length or number of examples.
|
||||
max_length = 5000
|
||||
limit = 0
|
||||
# Data augmentation
|
||||
orth_variant_level = 0.0
|
||||
dropout = 0.1
|
||||
# Controls early-stopping. 0 or -1 mean unlimited.
|
||||
patience = 1600
|
||||
max_epochs = 0
|
||||
max_steps = 20000
|
||||
eval_frequency = 200
|
||||
# Other settings
|
||||
seed = 0
|
||||
accumulate_gradient = 1
|
||||
use_pytorch_for_gpu_memory = false
|
||||
# Control how scores are printed and checkpoints are evaluated.
|
||||
eval_batch_size = 128
|
||||
score_weights = {"dep_las": 0.4, "ents_f": 0.4, "tag_acc": 0.2}
|
||||
[paths]
|
||||
train = ""
|
||||
dev = ""
|
||||
raw = null
|
||||
init_tok2vec = null
|
||||
discard_oversize = false
|
||||
batch_by = "words"
|
||||
raw_text = null
|
||||
tag_map = null
|
||||
vectors = null
|
||||
base_model = null
|
||||
morph_rules = null
|
||||
|
||||
[training.batch_size]
|
||||
[system]
|
||||
seed = 0
|
||||
use_pytorch_for_gpu_memory = false
|
||||
|
||||
[training]
|
||||
seed = ${system:seed}
|
||||
dropout = 0.1
|
||||
init_tok2vec = ${paths:init_tok2vec}
|
||||
vectors = null
|
||||
accumulate_gradient = 1
|
||||
max_steps = 0
|
||||
max_epochs = 0
|
||||
patience = 10000
|
||||
eval_frequency = 200
|
||||
score_weights = {"dep_las": 0.4, "ents_f": 0.4, "tag_acc": 0.2}
|
||||
|
||||
[training.train_corpus]
|
||||
@readers = "spacy.Corpus.v1"
|
||||
path = ${paths:train}
|
||||
gold_preproc = true
|
||||
max_length = 0
|
||||
limit = 0
|
||||
|
||||
[training.dev_corpus]
|
||||
@readers = "spacy.Corpus.v1"
|
||||
path = ${paths:dev}
|
||||
gold_preproc = ${training.read_train:gold_preproc}
|
||||
max_length = 0
|
||||
limit = 0
|
||||
|
||||
[training.batcher]
|
||||
@batchers = "batch_by_words.v1"
|
||||
discard_oversize = false
|
||||
tolerance = 0.2
|
||||
|
||||
[training.batcher.size]
|
||||
@schedules = "compounding.v1"
|
||||
start = 100
|
||||
stop = 1000
|
||||
|
|
|
@ -1,30 +1,45 @@
|
|||
[paths]
|
||||
train = ""
|
||||
dev = ""
|
||||
raw = null
|
||||
init_tok2vec = null
|
||||
|
||||
[system]
|
||||
seed = 0
|
||||
use_pytorch_for_gpu_memory = false
|
||||
|
||||
[training]
|
||||
seed = ${system:seed}
|
||||
dropout = 0.2
|
||||
init_tok2vec = ${paths:init_tok2vec}
|
||||
vectors = null
|
||||
accumulate_gradient = 1
|
||||
max_steps = 0
|
||||
max_epochs = 0
|
||||
patience = 10000
|
||||
eval_frequency = 200
|
||||
dropout = 0.2
|
||||
init_tok2vec = null
|
||||
vectors = null
|
||||
max_epochs = 100
|
||||
orth_variant_level = 0.0
|
||||
score_weights = {"dep_las": 0.8, "tag_acc": 0.2}
|
||||
|
||||
[training.read_train]
|
||||
@readers = "spacy.Corpus.v1"
|
||||
path = ${paths:train}
|
||||
gold_preproc = true
|
||||
max_length = 0
|
||||
scores = ["tag_acc", "dep_uas", "dep_las", "speed"]
|
||||
score_weights = {"dep_las": 0.8, "tag_acc": 0.2}
|
||||
limit = 0
|
||||
seed = 0
|
||||
accumulate_gradient = 1
|
||||
|
||||
[training.read_dev]
|
||||
@readers = "spacy.Corpus.v1"
|
||||
path = ${paths:dev}
|
||||
gold_preproc = ${training.read_train:gold_preproc}
|
||||
max_length = 0
|
||||
limit = 0
|
||||
|
||||
[training.batcher]
|
||||
@batchers = "batch_by_words.v1"
|
||||
discard_oversize = false
|
||||
raw_text = null
|
||||
tag_map = null
|
||||
morph_rules = null
|
||||
base_model = null
|
||||
tolerance = 0.2
|
||||
|
||||
eval_batch_size = 128
|
||||
use_pytorch_for_gpu_memory = false
|
||||
batch_by = "words"
|
||||
|
||||
[training.batch_size]
|
||||
[training.batcher.size]
|
||||
@schedules = "compounding.v1"
|
||||
start = 100
|
||||
stop = 1000
|
||||
|
|
|
@ -162,13 +162,12 @@ def debug_data(
|
|||
loading_train_error_message = ""
|
||||
loading_dev_error_message = ""
|
||||
with msg.loading("Loading corpus..."):
|
||||
corpus = Corpus(train_path, dev_path)
|
||||
try:
|
||||
train_dataset = list(corpus.train_dataset(nlp))
|
||||
train_dataset = list(Corpus(train_path)(nlp))
|
||||
except ValueError as e:
|
||||
loading_train_error_message = f"Training data cannot be loaded: {e}"
|
||||
try:
|
||||
dev_dataset = list(corpus.dev_dataset(nlp))
|
||||
dev_dataset = list(Corpus(dev_path)(nlp))
|
||||
except ValueError as e:
|
||||
loading_dev_error_message = f"Development data cannot be loaded: {e}"
|
||||
if loading_train_error_message or loading_dev_error_message:
|
||||
|
|
|
@ -64,9 +64,9 @@ def evaluate(
|
|||
msg.fail("Evaluation data not found", data_path, exits=1)
|
||||
if displacy_path and not displacy_path.exists():
|
||||
msg.fail("Visualization output directory not found", displacy_path, exits=1)
|
||||
corpus = Corpus(data_path, data_path)
|
||||
corpus = Corpus(data_path, gold_preproc=gold_preproc)
|
||||
nlp = util.load_model(model)
|
||||
dev_dataset = list(corpus.dev_dataset(nlp, gold_preproc=gold_preproc))
|
||||
dev_dataset = list(corpus(nlp))
|
||||
scores = nlp.evaluate(dev_dataset, verbose=False)
|
||||
metrics = {
|
||||
"TOK": "token_acc",
|
||||
|
|
|
@ -12,9 +12,9 @@ import typer
|
|||
|
||||
from ._util import app, Arg, Opt, parse_config_overrides, show_validation_error
|
||||
from ._util import import_code
|
||||
from ..gold import Corpus, Example
|
||||
from ..language import Language
|
||||
from .. import util
|
||||
from ..gold.example import Example
|
||||
from ..errors import Errors
|
||||
|
||||
|
||||
|
@ -28,8 +28,6 @@ from ..ml import models # noqa: F401
|
|||
def train_cli(
|
||||
# fmt: off
|
||||
ctx: typer.Context, # This is only used to read additional arguments
|
||||
train_path: Path = Arg(..., help="Location of training data", exists=True),
|
||||
dev_path: Path = Arg(..., help="Location of development data", exists=True),
|
||||
config_path: Path = Arg(..., help="Path to config file", exists=True),
|
||||
output_path: Optional[Path] = Opt(None, "--output", "--output-path", "-o", help="Output directory to store model in"),
|
||||
code_path: Optional[Path] = Opt(None, "--code-path", "-c", help="Path to Python file with additional code (registered functions) to be imported"),
|
||||
|
@ -51,12 +49,11 @@ def train_cli(
|
|||
referenced in the config.
|
||||
"""
|
||||
util.set_env_log(verbose)
|
||||
verify_cli_args(train_path, dev_path, config_path, output_path)
|
||||
verify_cli_args(config_path, output_path)
|
||||
overrides = parse_config_overrides(ctx.args)
|
||||
import_code(code_path)
|
||||
train(
|
||||
config_path,
|
||||
{"train": train_path, "dev": dev_path},
|
||||
output_path=output_path,
|
||||
config_overrides=overrides,
|
||||
use_gpu=use_gpu,
|
||||
|
@ -66,8 +63,6 @@ def train_cli(
|
|||
|
||||
def train(
|
||||
config_path: Path,
|
||||
data_paths: Dict[str, Path],
|
||||
raw_text: Optional[Path] = None,
|
||||
output_path: Optional[Path] = None,
|
||||
config_overrides: Dict[str, Any] = {},
|
||||
use_gpu: int = -1,
|
||||
|
@ -85,36 +80,24 @@ def train(
|
|||
fix_random_seed(config["training"]["seed"])
|
||||
with show_validation_error(config_path):
|
||||
nlp, config = util.load_model_from_config(config, overrides=config_overrides)
|
||||
if config["training"]["base_model"]:
|
||||
# TODO: do something to check base_nlp against regular nlp described in config?
|
||||
# If everything matches it will look something like:
|
||||
# base_nlp = util.load_model(config["training"]["base_model"])
|
||||
# nlp = base_nlp
|
||||
raise NotImplementedError("base_model not supported yet.")
|
||||
if config["training"]["vectors"] is not None:
|
||||
util.load_vectors_into_model(nlp, config["training"]["vectors"])
|
||||
verify_config(nlp)
|
||||
raw_text, tag_map, morph_rules, weights_data = load_from_paths(config)
|
||||
if config["training"]["use_pytorch_for_gpu_memory"]:
|
||||
if config.get("system", {}).get("use_pytorch_for_gpu_memory"):
|
||||
# It feels kind of weird to not have a default for this.
|
||||
use_pytorch_for_gpu_memory()
|
||||
training = config["training"]
|
||||
optimizer = training["optimizer"]
|
||||
limit = training["limit"]
|
||||
corpus = Corpus(data_paths["train"], data_paths["dev"], limit=limit)
|
||||
T_cfg = config["training"]
|
||||
optimizer = T_cfg["optimizer"]
|
||||
train_corpus = T_cfg["train_corpus"]
|
||||
dev_corpus = T_cfg["dev_corpus"]
|
||||
batcher = T_cfg["batcher"]
|
||||
if resume_training:
|
||||
msg.info("Resuming training")
|
||||
nlp.resume_training()
|
||||
else:
|
||||
msg.info(f"Initializing the nlp pipeline: {nlp.pipe_names}")
|
||||
train_examples = corpus.train_dataset(
|
||||
nlp,
|
||||
shuffle=False,
|
||||
gold_preproc=training["gold_preproc"],
|
||||
max_length=training["max_length"],
|
||||
)
|
||||
train_examples = list(train_examples)
|
||||
nlp.begin_training(lambda: train_examples)
|
||||
nlp.begin_training(lambda: train_corpus(nlp))
|
||||
|
||||
if tag_map:
|
||||
# Replace tag map with provided mapping
|
||||
|
@ -140,38 +123,35 @@ def train(
|
|||
msg.fail(err, exits=1)
|
||||
tok2vec.from_bytes(weights_data)
|
||||
|
||||
msg.info("Loading training corpus")
|
||||
train_batches = create_train_batches(nlp, corpus, training)
|
||||
evaluate = create_evaluation_callback(nlp, optimizer, corpus, training)
|
||||
|
||||
# Create iterator, which yields out info after each optimization step.
|
||||
msg.info("Start training")
|
||||
score_weights = T_cfg["score_weights"]
|
||||
training_step_iterator = train_while_improving(
|
||||
nlp,
|
||||
optimizer,
|
||||
train_batches,
|
||||
evaluate,
|
||||
dropout=training["dropout"],
|
||||
accumulate_gradient=training["accumulate_gradient"],
|
||||
patience=training["patience"],
|
||||
max_steps=training["max_steps"],
|
||||
eval_frequency=training["eval_frequency"],
|
||||
raw_text=raw_text,
|
||||
create_train_batches(train_corpus(nlp), batcher, T_cfg["max_epochs"]),
|
||||
create_evaluation_callback(nlp, dev_corpus, score_weights),
|
||||
dropout=T_cfg["dropout"],
|
||||
accumulate_gradient=T_cfg["accumulate_gradient"],
|
||||
patience=T_cfg["patience"],
|
||||
max_steps=T_cfg["max_steps"],
|
||||
eval_frequency=T_cfg["eval_frequency"],
|
||||
raw_text=None
|
||||
)
|
||||
msg.info(f"Training. Initial learn rate: {optimizer.learn_rate}")
|
||||
print_row = setup_printer(training, nlp)
|
||||
print_row = setup_printer(T_cfg, nlp)
|
||||
|
||||
try:
|
||||
progress = tqdm.tqdm(total=training["eval_frequency"], leave=False)
|
||||
progress = tqdm.tqdm(total=T_cfg["eval_frequency"], leave=False)
|
||||
for batch, info, is_best_checkpoint in training_step_iterator:
|
||||
progress.update(1)
|
||||
if is_best_checkpoint is not None:
|
||||
progress.close()
|
||||
print_row(info)
|
||||
if is_best_checkpoint and output_path is not None:
|
||||
update_meta(training, nlp, info)
|
||||
update_meta(T_cfg, nlp, info)
|
||||
nlp.to_disk(output_path / "model-best")
|
||||
progress = tqdm.tqdm(total=training["eval_frequency"], leave=False)
|
||||
progress = tqdm.tqdm(total=T_cfg["eval_frequency"], leave=False)
|
||||
except Exception as e:
|
||||
if output_path is not None:
|
||||
msg.warn(
|
||||
|
@ -192,70 +172,34 @@ def train(
|
|||
msg.good(f"Saved model to output directory {final_model_path}")
|
||||
|
||||
|
||||
def create_train_batches(
|
||||
nlp: Language, corpus: Corpus, cfg: Union[Config, Dict[str, Any]]
|
||||
):
|
||||
max_epochs = cfg["max_epochs"]
|
||||
train_examples = list(
|
||||
corpus.train_dataset(
|
||||
nlp,
|
||||
shuffle=True,
|
||||
gold_preproc=cfg["gold_preproc"],
|
||||
max_length=cfg["max_length"],
|
||||
)
|
||||
)
|
||||
epoch = 0
|
||||
batch_strategy = cfg["batch_by"]
|
||||
while True:
|
||||
if len(train_examples) == 0:
|
||||
raise ValueError(Errors.E988)
|
||||
epoch += 1
|
||||
if batch_strategy == "padded":
|
||||
batches = util.minibatch_by_padded_size(
|
||||
train_examples,
|
||||
size=cfg["batch_size"],
|
||||
buffer=256,
|
||||
discard_oversize=cfg["discard_oversize"],
|
||||
)
|
||||
elif batch_strategy == "words":
|
||||
batches = util.minibatch_by_words(
|
||||
train_examples,
|
||||
size=cfg["batch_size"],
|
||||
discard_oversize=cfg["discard_oversize"],
|
||||
)
|
||||
else:
|
||||
batches = util.minibatch(train_examples, size=cfg["batch_size"])
|
||||
# make sure the minibatch_by_words result is not empty, or we'll have an infinite training loop
|
||||
try:
|
||||
first = next(batches)
|
||||
yield epoch, first
|
||||
except StopIteration:
|
||||
raise ValueError(Errors.E986)
|
||||
for batch in batches:
|
||||
def create_train_batches(iterator, batcher, max_epochs: int):
|
||||
epoch = 1
|
||||
examples = []
|
||||
# Stream the first epoch, so we start training faster and support
|
||||
# infinite streams.
|
||||
for batch in batcher(iterator):
|
||||
yield epoch, batch
|
||||
if max_epochs != 1:
|
||||
examples.extend(batch)
|
||||
if not examples:
|
||||
# Raise error if no data
|
||||
raise ValueError(Errors.E986)
|
||||
while epoch != max_epochs:
|
||||
random.shuffle(examples)
|
||||
for batch in batcher(examples):
|
||||
yield epoch, batch
|
||||
if max_epochs >= 1 and epoch >= max_epochs:
|
||||
break
|
||||
random.shuffle(train_examples)
|
||||
epoch += 1
|
||||
|
||||
|
||||
def create_evaluation_callback(
|
||||
nlp: Language,
|
||||
optimizer: Optimizer,
|
||||
corpus: Corpus,
|
||||
cfg: Union[Config, Dict[str, Any]],
|
||||
dev_corpus: Callable,
|
||||
weights: Dict[str, float],
|
||||
) -> Callable[[], Tuple[float, Dict[str, float]]]:
|
||||
def evaluate() -> Tuple[float, Dict[str, float]]:
|
||||
dev_examples = corpus.dev_dataset(nlp, gold_preproc=cfg["gold_preproc"])
|
||||
dev_examples = list(dev_examples)
|
||||
n_words = sum(len(ex.predicted) for ex in dev_examples)
|
||||
batch_size = cfg["eval_batch_size"]
|
||||
if optimizer.averages:
|
||||
with nlp.use_params(optimizer.averages):
|
||||
scores = nlp.evaluate(dev_examples, batch_size=batch_size)
|
||||
else:
|
||||
scores = nlp.evaluate(dev_examples, batch_size=batch_size)
|
||||
dev_examples = list(dev_corpus(nlp))
|
||||
scores = nlp.evaluate(dev_examples)
|
||||
# Calculate a weighted sum based on score_weights for the main score
|
||||
weights = cfg["score_weights"]
|
||||
try:
|
||||
weighted_score = sum(scores[s] * weights.get(s, 0.0) for s in weights)
|
||||
except KeyError as e:
|
||||
|
@ -348,7 +292,11 @@ def train_while_improving(
|
|||
proc.model.finish_update(optimizer)
|
||||
optimizer.step_schedules()
|
||||
if not (step % eval_frequency):
|
||||
score, other_scores = evaluate()
|
||||
if optimizer.averages:
|
||||
with nlp.use_params(optimizer.averages):
|
||||
score, other_scores = evaluate()
|
||||
else:
|
||||
score, other_scores = evaluate()
|
||||
results.append((score, step))
|
||||
is_best_checkpoint = score == max(results)[0]
|
||||
else:
|
||||
|
@ -459,17 +407,7 @@ def load_from_paths(
|
|||
msg.fail("Can't find raw text", raw_text, exits=1)
|
||||
raw_text = list(srsly.read_jsonl(config["training"]["raw_text"]))
|
||||
tag_map = {}
|
||||
tag_map_path = util.ensure_path(config["training"]["tag_map"])
|
||||
if tag_map_path is not None:
|
||||
if not tag_map_path.exists():
|
||||
msg.fail("Can't find tag map path", tag_map_path, exits=1)
|
||||
tag_map = srsly.read_json(config["training"]["tag_map"])
|
||||
morph_rules = {}
|
||||
morph_rules_path = util.ensure_path(config["training"]["morph_rules"])
|
||||
if morph_rules_path is not None:
|
||||
if not morph_rules_path.exists():
|
||||
msg.fail("Can't find tag map path", morph_rules_path, exits=1)
|
||||
morph_rules = srsly.read_json(config["training"]["morph_rules"])
|
||||
weights_data = None
|
||||
init_tok2vec = util.ensure_path(config["training"]["init_tok2vec"])
|
||||
if init_tok2vec is not None:
|
||||
|
@ -481,18 +419,12 @@ def load_from_paths(
|
|||
|
||||
|
||||
def verify_cli_args(
|
||||
train_path: Path,
|
||||
dev_path: Path,
|
||||
config_path: Path,
|
||||
output_path: Optional[Path] = None,
|
||||
) -> None:
|
||||
# Make sure all files and paths exists if they are needed
|
||||
if not config_path or not config_path.exists():
|
||||
msg.fail("Config file not found", config_path, exits=1)
|
||||
if not train_path or not train_path.exists():
|
||||
msg.fail("Training data not found", train_path, exits=1)
|
||||
if not dev_path or not dev_path.exists():
|
||||
msg.fail("Development data not found", dev_path, exits=1)
|
||||
if output_path is not None:
|
||||
if not output_path.exists():
|
||||
output_path.mkdir()
|
||||
|
|
|
@ -1,3 +1,13 @@
|
|||
[paths]
|
||||
train = ""
|
||||
dev = ""
|
||||
raw = null
|
||||
init_tok2vec = null
|
||||
|
||||
[system]
|
||||
seed = 0
|
||||
use_pytorch_for_gpu_memory = false
|
||||
|
||||
[nlp]
|
||||
lang = null
|
||||
pipeline = []
|
||||
|
@ -13,38 +23,55 @@ load_vocab_data = true
|
|||
|
||||
# Training hyper-parameters and additional features.
|
||||
[training]
|
||||
# Whether to train on sequences with 'gold standard' sentence boundaries
|
||||
# and tokens. If you set this to true, take care to ensure your run-time
|
||||
# data is passed in sentence-by-sentence via some prior preprocessing.
|
||||
gold_preproc = false
|
||||
# Limitations on training document length or number of examples.
|
||||
max_length = 5000
|
||||
limit = 0
|
||||
# Data augmentation
|
||||
orth_variant_level = 0.0
|
||||
seed = ${system:seed}
|
||||
dropout = 0.1
|
||||
accumulate_gradient = 1
|
||||
# Extra resources for transfer-learning or pseudo-rehearsal
|
||||
init_tok2vec = ${paths:init_tok2vec}
|
||||
raw_text = ${paths:raw}
|
||||
vectors = null
|
||||
# Controls early-stopping. 0 or -1 mean unlimited.
|
||||
patience = 1600
|
||||
max_epochs = 0
|
||||
max_steps = 20000
|
||||
eval_frequency = 200
|
||||
eval_batch_size = 128
|
||||
# Other settings
|
||||
seed = 0
|
||||
accumulate_gradient = 1
|
||||
use_pytorch_for_gpu_memory = false
|
||||
# Control how scores are printed and checkpoints are evaluated.
|
||||
score_weights = {}
|
||||
# These settings are invalid for the transformer models.
|
||||
init_tok2vec = null
|
||||
|
||||
[training.train_corpus]
|
||||
@readers = "spacy.Corpus.v1"
|
||||
path = ${paths:train}
|
||||
# Whether to train on sequences with 'gold standard' sentence boundaries
|
||||
# and tokens. If you set this to true, take care to ensure your run-time
|
||||
# data is passed in sentence-by-sentence via some prior preprocessing.
|
||||
gold_preproc = false
|
||||
# Limitations on training document length
|
||||
max_length = 2000
|
||||
# Limitation on number of training examples
|
||||
limit = 0
|
||||
|
||||
[training.dev_corpus]
|
||||
@readers = "spacy.Corpus.v1"
|
||||
path = ${paths:dev}
|
||||
# Whether to train on sequences with 'gold standard' sentence boundaries
|
||||
# and tokens. If you set this to true, take care to ensure your run-time
|
||||
# data is passed in sentence-by-sentence via some prior preprocessing.
|
||||
gold_preproc = false
|
||||
# Limitations on training document length
|
||||
max_length = 2000
|
||||
# Limitation on number of training examples
|
||||
limit = 0
|
||||
|
||||
[training.batcher]
|
||||
@batchers = "batch_by_words.v1"
|
||||
discard_oversize = false
|
||||
raw_text = null
|
||||
tag_map = null
|
||||
morph_rules = null
|
||||
base_model = null
|
||||
vectors = null
|
||||
batch_by = "words"
|
||||
batch_size = 1000
|
||||
tolerance = 0.2
|
||||
|
||||
[training.batcher.size]
|
||||
@schedules = "compounding.v1"
|
||||
start = 100
|
||||
stop = 1000
|
||||
compound = 1.001
|
||||
|
||||
[training.optimizer]
|
||||
@optimizers = "Adam.v1"
|
||||
|
@ -69,8 +96,8 @@ max_length = 500
|
|||
dropout = 0.2
|
||||
n_save_every = null
|
||||
batch_size = 3000
|
||||
seed = ${training:seed}
|
||||
use_pytorch_for_gpu_memory = ${training:use_pytorch_for_gpu_memory}
|
||||
seed = ${system:seed}
|
||||
use_pytorch_for_gpu_memory = ${system:use_pytorch_for_gpu_memory}
|
||||
tok2vec_model = "components.tok2vec.model"
|
||||
|
||||
[pretraining.objective]
|
||||
|
|
|
@ -9,3 +9,6 @@ from .iob_utils import tags_to_entities
|
|||
|
||||
from .gold_io import docs_to_json
|
||||
from .gold_io import read_json_file
|
||||
|
||||
|
||||
from .batchers import minibatch_by_padded_size, minibatch_by_words
|
||||
|
|
172
spacy/gold/batchers.py
Normal file
172
spacy/gold/batchers.py
Normal file
|
@ -0,0 +1,172 @@
|
|||
from typing import Union, Iterator, Iterable, Sequence, TypeVar, List, Callable
|
||||
from typing import Optional, Any
|
||||
from functools import partial
|
||||
import itertools
|
||||
|
||||
from .example import Example
|
||||
from ..util import registry, minibatch
|
||||
|
||||
|
||||
Sizing = Union[Iterable[int], int]
|
||||
ItemT = TypeVar("ItemT")
|
||||
BatcherT = Callable[[Iterable[ItemT]], Iterable[List[ItemT]]]
|
||||
|
||||
|
||||
@registry.batchers("batch_by_padded.v1")
|
||||
def configure_minibatch_by_padded_size(
|
||||
*,
|
||||
size: Sizing,
|
||||
buffer: int,
|
||||
discard_oversize: bool,
|
||||
get_length: Optional[Callable[[ItemT], int]] = None
|
||||
) -> BatcherT:
|
||||
# Avoid displacing optional values from the underlying function.
|
||||
optionals = {"get_length": get_length} if get_length is not None else {}
|
||||
return partial(
|
||||
minibatch_by_padded_size,
|
||||
size=size,
|
||||
buffer=buffer,
|
||||
discard_oversize=discard_oversize,
|
||||
**optionals
|
||||
)
|
||||
|
||||
|
||||
@registry.batchers("batch_by_words.v1")
|
||||
def configure_minibatch_by_words(
|
||||
*,
|
||||
size: Sizing,
|
||||
tolerance: float,
|
||||
discard_oversize: bool,
|
||||
get_length: Optional[Callable[[ItemT], int]] = None
|
||||
) -> BatcherT:
|
||||
optionals = {"get_length": get_length} if get_length is not None else {}
|
||||
return partial(
|
||||
minibatch_by_words,
|
||||
size=size,
|
||||
discard_oversize=discard_oversize,
|
||||
**optionals
|
||||
)
|
||||
|
||||
|
||||
@registry.batchers("batch_by_sequence.v1")
|
||||
def configure_minibatch(size: Sizing, get_length=None) -> BatcherT:
|
||||
optionals = ({"get_length": get_length} if get_length is not None else {})
|
||||
return partial(minibatch, size=size, **optionals)
|
||||
|
||||
|
||||
def minibatch_by_padded_size(
|
||||
docs: Iterator["Doc"],
|
||||
size: Sizing,
|
||||
buffer: int = 256,
|
||||
discard_oversize: bool = False,
|
||||
get_length: Callable = len,
|
||||
) -> Iterator[Iterator["Doc"]]:
|
||||
if isinstance(size, int):
|
||||
size_ = itertools.repeat(size)
|
||||
else:
|
||||
size_ = size
|
||||
for outer_batch in minibatch(docs, size=buffer):
|
||||
outer_batch = list(outer_batch)
|
||||
target_size = next(size_)
|
||||
for indices in _batch_by_length(outer_batch, target_size, get_length):
|
||||
subbatch = [outer_batch[i] for i in indices]
|
||||
padded_size = max(len(seq) for seq in subbatch) * len(subbatch)
|
||||
if discard_oversize and padded_size >= target_size:
|
||||
pass
|
||||
else:
|
||||
yield subbatch
|
||||
|
||||
|
||||
def minibatch_by_words(
|
||||
docs, size, tolerance=0.2, discard_oversize=False, get_length=len
|
||||
):
|
||||
"""Create minibatches of roughly a given number of words. If any examples
|
||||
are longer than the specified batch length, they will appear in a batch by
|
||||
themselves, or be discarded if discard_oversize=True.
|
||||
The argument 'docs' can be a list of strings, Doc's or Example's. """
|
||||
if isinstance(size, int):
|
||||
size_ = itertools.repeat(size)
|
||||
elif isinstance(size, List):
|
||||
size_ = iter(size)
|
||||
else:
|
||||
size_ = size
|
||||
target_size = next(size_)
|
||||
tol_size = target_size * tolerance
|
||||
batch = []
|
||||
overflow = []
|
||||
batch_size = 0
|
||||
overflow_size = 0
|
||||
for doc in docs:
|
||||
n_words = get_length(doc)
|
||||
# if the current example exceeds the maximum batch size, it is returned separately
|
||||
# but only if discard_oversize=False.
|
||||
if n_words > target_size + tol_size:
|
||||
if not discard_oversize:
|
||||
yield [doc]
|
||||
# add the example to the current batch if there's no overflow yet and it still fits
|
||||
elif overflow_size == 0 and (batch_size + n_words) <= target_size:
|
||||
batch.append(doc)
|
||||
batch_size += n_words
|
||||
# add the example to the overflow buffer if it fits in the tolerance margin
|
||||
elif (batch_size + overflow_size + n_words) <= (target_size + tol_size):
|
||||
overflow.append(doc)
|
||||
overflow_size += n_words
|
||||
# yield the previous batch and start a new one. The new one gets the overflow examples.
|
||||
else:
|
||||
if batch:
|
||||
yield batch
|
||||
target_size = next(size_)
|
||||
tol_size = target_size * tolerance
|
||||
batch = overflow
|
||||
batch_size = overflow_size
|
||||
overflow = []
|
||||
overflow_size = 0
|
||||
# this example still fits
|
||||
if (batch_size + n_words) <= target_size:
|
||||
batch.append(doc)
|
||||
batch_size += n_words
|
||||
# this example fits in overflow
|
||||
elif (batch_size + n_words) <= (target_size + tol_size):
|
||||
overflow.append(doc)
|
||||
overflow_size += n_words
|
||||
# this example does not fit with the previous overflow: start another new batch
|
||||
else:
|
||||
if batch:
|
||||
yield batch
|
||||
target_size = next(size_)
|
||||
tol_size = target_size * tolerance
|
||||
batch = [doc]
|
||||
batch_size = n_words
|
||||
batch.extend(overflow)
|
||||
if batch:
|
||||
yield batch
|
||||
|
||||
|
||||
def _batch_by_length(
|
||||
seqs: Sequence[Any], max_words: int, get_length=len
|
||||
) -> List[List[Any]]:
|
||||
"""Given a list of sequences, return a batched list of indices into the
|
||||
list, where the batches are grouped by length, in descending order.
|
||||
|
||||
Batches may be at most max_words in size, defined as max sequence length * size.
|
||||
"""
|
||||
# Use negative index so we can get sort by position ascending.
|
||||
lengths_indices = [(get_length(seq), i) for i, seq in enumerate(seqs)]
|
||||
lengths_indices.sort()
|
||||
batches = []
|
||||
batch = []
|
||||
for length, i in lengths_indices:
|
||||
if not batch:
|
||||
batch.append(i)
|
||||
elif length * (len(batch) + 1) <= max_words:
|
||||
batch.append(i)
|
||||
else:
|
||||
batches.append(batch)
|
||||
batch = [i]
|
||||
if batch:
|
||||
batches.append(batch)
|
||||
# Check lengths match
|
||||
assert sum(len(b) for b in batches) == len(seqs)
|
||||
batches = [list(sorted(batch)) for batch in batches]
|
||||
batches.reverse()
|
||||
return batches
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Union, List, Iterable, Iterator, TYPE_CHECKING
|
||||
from typing import Union, List, Iterable, Iterator, TYPE_CHECKING, Callable, Tuple
|
||||
from pathlib import Path
|
||||
import random
|
||||
|
||||
|
@ -12,26 +12,38 @@ if TYPE_CHECKING:
|
|||
from ..language import Language # noqa: F401
|
||||
|
||||
|
||||
@util.registry.readers("spacy.Corpus.v1")
|
||||
def create_docbin_reader(
|
||||
path: Path, gold_preproc: bool, max_length: int = 0, limit: int = 0
|
||||
) -> Callable[["Language"], Iterable[Example]]:
|
||||
return Corpus(path, gold_preproc=gold_preproc, max_length=max_length, limit=limit)
|
||||
|
||||
|
||||
class Corpus:
|
||||
"""An annotated corpus, reading train and dev datasets from
|
||||
the DocBin (.spacy) format.
|
||||
"""Iterate Example objects from a file or directory of DocBin (.spacy)
|
||||
formated data files.
|
||||
|
||||
path (Path): The directory or filename to read from.
|
||||
gold_preproc (bool): Whether to set up the Example object with gold-standard
|
||||
sentences and tokens for the predictions. Gold preprocessing helps
|
||||
the annotations align to the tokenization, and may result in sequences
|
||||
of more consistent length. However, it may reduce run-time accuracy due
|
||||
to train/test skew. Defaults to False.
|
||||
max_length (int): Maximum document length. Longer documents will be
|
||||
split into sentences, if sentence boundaries are available. Defaults to
|
||||
0, which indicates no limit.
|
||||
limit (int): Limit corpus to a subset of examples, e.g. for debugging.
|
||||
Defaults to 0, which indicates no limit.
|
||||
|
||||
DOCS: https://spacy.io/api/corpus
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, train_loc: Union[str, Path], dev_loc: Union[str, Path], limit: int = 0
|
||||
self, path, *, limit: int = 0, gold_preproc: bool = False, max_length: bool = False,
|
||||
) -> None:
|
||||
"""Create a Corpus.
|
||||
|
||||
train (str / Path): File or directory of training data.
|
||||
dev (str / Path): File or directory of development data.
|
||||
limit (int): Max. number of examples returned.
|
||||
|
||||
DOCS: https://spacy.io/api/corpus#init
|
||||
"""
|
||||
self.train_loc = train_loc
|
||||
self.dev_loc = dev_loc
|
||||
self.path = util.ensure_path(path)
|
||||
self.gold_preproc = gold_preproc
|
||||
self.max_length = max_length
|
||||
self.limit = limit
|
||||
|
||||
@staticmethod
|
||||
|
@ -54,6 +66,22 @@ class Corpus:
|
|||
locs.append(path)
|
||||
return locs
|
||||
|
||||
def __call__(self, nlp: "Language") -> Iterator[Example]:
|
||||
"""Yield examples from the data.
|
||||
|
||||
nlp (Language): The current nlp object.
|
||||
loc (Path): The file or directory to read from.
|
||||
YIELDS (Example): The examples.
|
||||
|
||||
DOCS: https://spacy.io/api/corpus#call
|
||||
"""
|
||||
ref_docs = self.read_docbin(nlp.vocab, self.walk_corpus(self.path))
|
||||
if self.gold_preproc:
|
||||
examples = self.make_examples_gold_preproc(nlp, ref_docs)
|
||||
else:
|
||||
examples = self.make_examples(nlp, ref_docs, self.max_length)
|
||||
yield from examples
|
||||
|
||||
def _make_example(
|
||||
self, nlp: "Language", reference: Doc, gold_preproc: bool
|
||||
) -> Example:
|
||||
|
@ -114,68 +142,3 @@ class Corpus:
|
|||
i += 1
|
||||
if self.limit >= 1 and i >= self.limit:
|
||||
break
|
||||
|
||||
def count_train(self, nlp: "Language") -> int:
|
||||
"""Returns count of words in train examples.
|
||||
|
||||
nlp (Language): The current nlp. object.
|
||||
RETURNS (int): The word count.
|
||||
|
||||
DOCS: https://spacy.io/api/corpus#count_train
|
||||
"""
|
||||
n = 0
|
||||
i = 0
|
||||
for example in self.train_dataset(nlp):
|
||||
n += len(example.predicted)
|
||||
if self.limit >= 0 and i >= self.limit:
|
||||
break
|
||||
i += 1
|
||||
return n
|
||||
|
||||
def train_dataset(
|
||||
self,
|
||||
nlp: "Language",
|
||||
*,
|
||||
shuffle: bool = True,
|
||||
gold_preproc: bool = False,
|
||||
max_length: int = 0
|
||||
) -> Iterator[Example]:
|
||||
"""Yield examples from the training data.
|
||||
|
||||
nlp (Language): The current nlp object.
|
||||
shuffle (bool): Whether to shuffle the examples.
|
||||
gold_preproc (bool): Whether to train on gold-standard sentences and tokens.
|
||||
max_length (int): Maximum document length. Longer documents will be
|
||||
split into sentences, if sentence boundaries are available. 0 for
|
||||
no limit.
|
||||
YIELDS (Example): The examples.
|
||||
|
||||
DOCS: https://spacy.io/api/corpus#train_dataset
|
||||
"""
|
||||
ref_docs = self.read_docbin(nlp.vocab, self.walk_corpus(self.train_loc))
|
||||
if gold_preproc:
|
||||
examples = self.make_examples_gold_preproc(nlp, ref_docs)
|
||||
else:
|
||||
examples = self.make_examples(nlp, ref_docs, max_length)
|
||||
if shuffle:
|
||||
examples = list(examples)
|
||||
random.shuffle(examples)
|
||||
yield from examples
|
||||
|
||||
def dev_dataset(
|
||||
self, nlp: "Language", *, gold_preproc: bool = False
|
||||
) -> Iterator[Example]:
|
||||
"""Yield examples from the development data.
|
||||
|
||||
nlp (Language): The current nlp object.
|
||||
gold_preproc (bool): Whether to train on gold-standard sentences and tokens.
|
||||
YIELDS (Example): The examples.
|
||||
|
||||
DOCS: https://spacy.io/api/corpus#dev_dataset
|
||||
"""
|
||||
ref_docs = self.read_docbin(nlp.vocab, self.walk_corpus(self.dev_loc))
|
||||
if gold_preproc:
|
||||
examples = self.make_examples_gold_preproc(nlp, ref_docs)
|
||||
else:
|
||||
examples = self.make_examples(nlp, ref_docs, max_length=0)
|
||||
yield from examples
|
||||
|
|
|
@ -1,13 +1,18 @@
|
|||
from typing import Dict, List, Union, Optional, Sequence, Any, Callable, Type
|
||||
from typing import Iterable, TypeVar
|
||||
from enum import Enum
|
||||
from pydantic import BaseModel, Field, ValidationError, validator
|
||||
from pydantic import StrictStr, StrictInt, StrictFloat, StrictBool
|
||||
from pydantic import root_validator
|
||||
from collections import defaultdict
|
||||
from thinc.api import Optimizer
|
||||
from pathlib import Path
|
||||
|
||||
from .attrs import NAMES
|
||||
|
||||
ItemT = TypeVar("ItemT")
|
||||
Batcher = Callable[[Iterable[ItemT]], Iterable[List[ItemT]]]
|
||||
|
||||
|
||||
def validate(schema: Type[BaseModel], obj: Dict[str, Any]) -> List[str]:
|
||||
"""Validate data against a given pydantic schema.
|
||||
|
@ -178,32 +183,24 @@ class ModelMetaSchema(BaseModel):
|
|||
# check that against this schema in the test suite to make sure it's always
|
||||
# up to date.
|
||||
|
||||
Reader = Callable[["Language", str], Iterable["Example"]]
|
||||
|
||||
class ConfigSchemaTraining(BaseModel):
|
||||
# fmt: off
|
||||
base_model: Optional[StrictStr] = Field(..., title="The base model to use")
|
||||
vectors: Optional[StrictStr] = Field(..., title="Path to vectors")
|
||||
gold_preproc: StrictBool = Field(..., title="Whether to train on gold-standard sentences and tokens")
|
||||
max_length: StrictInt = Field(..., title="Maximum length of examples (longer examples are divided into sentences if possible)")
|
||||
limit: StrictInt = Field(..., title="Number of examples to use (0 for all)")
|
||||
orth_variant_level: StrictFloat = Field(..., title="Orth variants for data augmentation")
|
||||
train_corpus: Reader = Field(..., title="Reader for the training data")
|
||||
dev_corpus: Reader = Field(..., title="Reader for the dev data")
|
||||
batcher: Batcher = Field(..., title="Batcher for the training data")
|
||||
dropout: StrictFloat = Field(..., title="Dropout rate")
|
||||
patience: StrictInt = Field(..., title="How many steps to continue without improvement in evaluation score")
|
||||
max_epochs: StrictInt = Field(..., title="Maximum number of epochs to train for")
|
||||
max_steps: StrictInt = Field(..., title="Maximum number of update steps to train for")
|
||||
eval_frequency: StrictInt = Field(..., title="How often to evaluate during training (steps)")
|
||||
eval_batch_size: StrictInt = Field(..., title="Evaluation batch size")
|
||||
seed: Optional[StrictInt] = Field(..., title="Random seed")
|
||||
accumulate_gradient: StrictInt = Field(..., title="Whether to divide the batch up into substeps")
|
||||
use_pytorch_for_gpu_memory: StrictBool = Field(..., title="Allocate memory via PyTorch")
|
||||
score_weights: Dict[StrictStr, Union[StrictFloat, StrictInt]] = Field(..., title="Scores to report and their weights for selecting final model")
|
||||
init_tok2vec: Optional[StrictStr] = Field(..., title="Path to pretrained tok2vec weights")
|
||||
discard_oversize: StrictBool = Field(..., title="Whether to skip examples longer than batch size")
|
||||
batch_by: StrictStr = Field(..., title="Batch examples by type")
|
||||
raw_text: Optional[StrictStr] = Field(..., title="Raw text")
|
||||
tag_map: Optional[StrictStr] = Field(..., title="Path to JSON-formatted tag map")
|
||||
morph_rules: Optional[StrictStr] = Field(..., title="Path to morphology rules")
|
||||
batch_size: Union[Sequence[int], int] = Field(..., title="The batch size or batch size schedule")
|
||||
raw_text: Optional[StrictStr] = Field(default=None, title="Raw text")
|
||||
optimizer: Optimizer = Field(..., title="The optimizer to use")
|
||||
# fmt: on
|
||||
|
||||
|
@ -211,6 +208,7 @@ class ConfigSchemaTraining(BaseModel):
|
|||
extra = "forbid"
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
#eval_batch_size: StrictInt = Field(..., title="Evaluation batch size")
|
||||
|
||||
class ConfigSchemaNlp(BaseModel):
|
||||
# fmt: off
|
||||
|
|
|
@ -210,7 +210,7 @@ def test_train_empty():
|
|||
nlp.begin_training()
|
||||
for itn in range(2):
|
||||
losses = {}
|
||||
batches = util.minibatch(train_examples)
|
||||
batches = util.minibatch(train_examples, size=8)
|
||||
for batch in batches:
|
||||
nlp.update(batch, losses=losses)
|
||||
|
||||
|
|
|
@ -438,9 +438,8 @@ def test_issue4402():
|
|||
data = DocBin(docs=docs, attrs=attrs).to_bytes()
|
||||
with output_file.open("wb") as file_:
|
||||
file_.write(data)
|
||||
corpus = Corpus(train_loc=str(output_file), dev_loc=str(output_file))
|
||||
|
||||
train_data = list(corpus.train_dataset(nlp))
|
||||
reader = Corpus(output_file)
|
||||
train_data = list(reader(nlp))
|
||||
assert len(train_data) == 2
|
||||
|
||||
split_train_data = []
|
||||
|
|
|
@ -11,8 +11,23 @@ from ..util import make_tempdir
|
|||
|
||||
|
||||
nlp_config_string = """
|
||||
[paths]
|
||||
train = ""
|
||||
dev = ""
|
||||
|
||||
[training]
|
||||
batch_size = 666
|
||||
|
||||
[training.train_corpus]
|
||||
@readers = "spacy.Corpus.v1"
|
||||
path = ${paths:train}
|
||||
|
||||
[training.dev_corpus]
|
||||
@readers = "spacy.Corpus.v1"
|
||||
path = ${paths:dev}
|
||||
|
||||
[training.batcher]
|
||||
@batchers = "batch_by_words.v1"
|
||||
size = 666
|
||||
|
||||
[nlp]
|
||||
lang = "en"
|
||||
|
@ -93,7 +108,7 @@ def test_create_nlp_from_config():
|
|||
with pytest.raises(ConfigValidationError):
|
||||
nlp, _ = load_model_from_config(config, auto_fill=False)
|
||||
nlp, resolved = load_model_from_config(config, auto_fill=True)
|
||||
assert nlp.config["training"]["batch_size"] == 666
|
||||
assert nlp.config["training"]["batcher"]["size"] == 666
|
||||
assert len(nlp.config["training"]) > 1
|
||||
assert nlp.pipe_names == ["tok2vec", "tagger"]
|
||||
assert len(nlp.config["components"]) == 2
|
||||
|
|
|
@ -483,14 +483,14 @@ def test_roundtrip_docs_to_docbin(doc):
|
|||
reloaded_nlp = English()
|
||||
json_file = tmpdir / "roundtrip.json"
|
||||
srsly.write_json(json_file, [docs_to_json(doc)])
|
||||
goldcorpus = Corpus(str(json_file), str(json_file))
|
||||
output_file = tmpdir / "roundtrip.spacy"
|
||||
data = DocBin(docs=[doc]).to_bytes()
|
||||
with output_file.open("wb") as file_:
|
||||
file_.write(data)
|
||||
goldcorpus = Corpus(train_loc=str(output_file), dev_loc=str(output_file))
|
||||
reloaded_example = next(goldcorpus.dev_dataset(nlp=reloaded_nlp))
|
||||
assert len(doc) == goldcorpus.count_train(reloaded_nlp)
|
||||
reader = Corpus(output_file)
|
||||
reloaded_examples = list(reader(reloaded_nlp))
|
||||
assert len(doc) == sum(len(eg) for eg in reloaded_examples)
|
||||
reloaded_example = reloaded_examples[0]
|
||||
assert text == reloaded_example.reference.text
|
||||
assert idx == [t.idx for t in reloaded_example.reference]
|
||||
assert tags == [t.tag_ for t in reloaded_example.reference]
|
||||
|
@ -515,10 +515,9 @@ def test_make_orth_variants(doc):
|
|||
data = DocBin(docs=[doc]).to_bytes()
|
||||
with output_file.open("wb") as file_:
|
||||
file_.write(data)
|
||||
goldcorpus = Corpus(train_loc=str(output_file), dev_loc=str(output_file))
|
||||
|
||||
# due to randomness, test only that this runs with no errors for now
|
||||
train_example = next(goldcorpus.train_dataset(nlp))
|
||||
reader = Corpus(output_file)
|
||||
train_example = next(reader(nlp))
|
||||
make_orth_variants_example(nlp, train_example, orth_variant_level=0.2)
|
||||
|
||||
|
||||
|
|
|
@ -3,8 +3,9 @@ import pytest
|
|||
from .util import get_random_doc
|
||||
|
||||
from spacy import util
|
||||
from spacy.util import minibatch_by_words, dot_to_object
|
||||
from spacy.util import dot_to_object
|
||||
from thinc.api import Config, Optimizer
|
||||
from spacy.gold.batchers import minibatch_by_words
|
||||
|
||||
from ..lang.en import English
|
||||
from ..lang.nl import Dutch
|
||||
|
|
160
spacy/util.py
160
spacy/util.py
|
@ -67,6 +67,8 @@ class registry(thinc.registry):
|
|||
lookups = catalogue.create("spacy", "lookups", entry_points=True)
|
||||
displacy_colors = catalogue.create("spacy", "displacy_colors", entry_points=True)
|
||||
assets = catalogue.create("spacy", "assets", entry_points=True)
|
||||
batchers = catalogue.create("spacy", "batchers", entry_points=True)
|
||||
readers = catalogue.create("spacy", "readers", entry_points=True)
|
||||
# These are factories registered via third-party packages and the
|
||||
# spacy_factories entry point. This registry only exists so we can easily
|
||||
# load them via the entry points. The "true" factories are added via the
|
||||
|
@ -747,144 +749,6 @@ def normalize_slice(
|
|||
return start, stop
|
||||
|
||||
|
||||
def minibatch(
|
||||
items: Iterable[Any], size: Union[Iterator[int], int] = 8
|
||||
) -> Iterator[Any]:
|
||||
"""Iterate over batches of items. `size` may be an iterator,
|
||||
so that batch-size can vary on each step.
|
||||
"""
|
||||
if isinstance(size, int):
|
||||
size_ = itertools.repeat(size)
|
||||
else:
|
||||
size_ = size
|
||||
items = iter(items)
|
||||
while True:
|
||||
batch_size = next(size_)
|
||||
batch = list(itertools.islice(items, int(batch_size)))
|
||||
if len(batch) == 0:
|
||||
break
|
||||
yield list(batch)
|
||||
|
||||
|
||||
def minibatch_by_padded_size(
|
||||
docs: Iterator["Doc"],
|
||||
size: Union[Iterator[int], int],
|
||||
buffer: int = 256,
|
||||
discard_oversize: bool = False,
|
||||
) -> Iterator[Iterator["Doc"]]:
|
||||
if isinstance(size, int):
|
||||
size_ = itertools.repeat(size)
|
||||
else:
|
||||
size_ = size
|
||||
for outer_batch in minibatch(docs, buffer):
|
||||
outer_batch = list(outer_batch)
|
||||
target_size = next(size_)
|
||||
for indices in _batch_by_length(outer_batch, target_size):
|
||||
subbatch = [outer_batch[i] for i in indices]
|
||||
padded_size = max(len(seq) for seq in subbatch) * len(subbatch)
|
||||
if discard_oversize and padded_size >= target_size:
|
||||
pass
|
||||
else:
|
||||
yield subbatch
|
||||
|
||||
|
||||
def _batch_by_length(seqs: Sequence[Any], max_words: int) -> List[List[Any]]:
|
||||
"""Given a list of sequences, return a batched list of indices into the
|
||||
list, where the batches are grouped by length, in descending order.
|
||||
|
||||
Batches may be at most max_words in size, defined as max sequence length * size.
|
||||
"""
|
||||
# Use negative index so we can get sort by position ascending.
|
||||
lengths_indices = [(len(seq), i) for i, seq in enumerate(seqs)]
|
||||
lengths_indices.sort()
|
||||
batches = []
|
||||
batch = []
|
||||
for length, i in lengths_indices:
|
||||
if not batch:
|
||||
batch.append(i)
|
||||
elif length * (len(batch) + 1) <= max_words:
|
||||
batch.append(i)
|
||||
else:
|
||||
batches.append(batch)
|
||||
batch = [i]
|
||||
if batch:
|
||||
batches.append(batch)
|
||||
# Check lengths match
|
||||
assert sum(len(b) for b in batches) == len(seqs)
|
||||
batches = [list(sorted(batch)) for batch in batches]
|
||||
batches.reverse()
|
||||
return batches
|
||||
|
||||
|
||||
def minibatch_by_words(docs, size, tolerance=0.2, discard_oversize=False):
|
||||
"""Create minibatches of roughly a given number of words. If any examples
|
||||
are longer than the specified batch length, they will appear in a batch by
|
||||
themselves, or be discarded if discard_oversize=True.
|
||||
The argument 'docs' can be a list of strings, Doc's or Example's. """
|
||||
from .gold import Example
|
||||
|
||||
if isinstance(size, int):
|
||||
size_ = itertools.repeat(size)
|
||||
elif isinstance(size, List):
|
||||
size_ = iter(size)
|
||||
else:
|
||||
size_ = size
|
||||
target_size = next(size_)
|
||||
tol_size = target_size * tolerance
|
||||
batch = []
|
||||
overflow = []
|
||||
batch_size = 0
|
||||
overflow_size = 0
|
||||
for doc in docs:
|
||||
if isinstance(doc, Example):
|
||||
n_words = len(doc.reference)
|
||||
elif isinstance(doc, str):
|
||||
n_words = len(doc.split())
|
||||
else:
|
||||
n_words = len(doc)
|
||||
# if the current example exceeds the maximum batch size, it is returned separately
|
||||
# but only if discard_oversize=False.
|
||||
if n_words > target_size + tol_size:
|
||||
if not discard_oversize:
|
||||
yield [doc]
|
||||
# add the example to the current batch if there's no overflow yet and it still fits
|
||||
elif overflow_size == 0 and (batch_size + n_words) <= target_size:
|
||||
batch.append(doc)
|
||||
batch_size += n_words
|
||||
# add the example to the overflow buffer if it fits in the tolerance margin
|
||||
elif (batch_size + overflow_size + n_words) <= (target_size + tol_size):
|
||||
overflow.append(doc)
|
||||
overflow_size += n_words
|
||||
# yield the previous batch and start a new one. The new one gets the overflow examples.
|
||||
else:
|
||||
if batch:
|
||||
yield batch
|
||||
target_size = next(size_)
|
||||
tol_size = target_size * tolerance
|
||||
batch = overflow
|
||||
batch_size = overflow_size
|
||||
overflow = []
|
||||
overflow_size = 0
|
||||
# this example still fits
|
||||
if (batch_size + n_words) <= target_size:
|
||||
batch.append(doc)
|
||||
batch_size += n_words
|
||||
# this example fits in overflow
|
||||
elif (batch_size + n_words) <= (target_size + tol_size):
|
||||
overflow.append(doc)
|
||||
overflow_size += n_words
|
||||
# this example does not fit with the previous overflow: start another new batch
|
||||
else:
|
||||
if batch:
|
||||
yield batch
|
||||
target_size = next(size_)
|
||||
tol_size = target_size * tolerance
|
||||
batch = [doc]
|
||||
batch_size = n_words
|
||||
batch.extend(overflow)
|
||||
if batch:
|
||||
yield batch
|
||||
|
||||
|
||||
def filter_spans(spans: Iterable["Span"]) -> List["Span"]:
|
||||
"""Filter a sequence of spans and remove duplicates or overlaps. Useful for
|
||||
|
@ -1217,3 +1081,23 @@ def create_default_optimizer() -> Optimizer:
|
|||
L2_is_weight_decay=L2_is_weight_decay,
|
||||
)
|
||||
return optimizer
|
||||
|
||||
|
||||
def minibatch(items, size):
|
||||
"""Iterate over batches of items. `size` may be an iterator,
|
||||
so that batch-size can vary on each step.
|
||||
"""
|
||||
if isinstance(size, int):
|
||||
size_ = itertools.repeat(size)
|
||||
else:
|
||||
size_ = size
|
||||
items = iter(items)
|
||||
while True:
|
||||
batch_size = next(size_)
|
||||
batch = list(itertools.islice(items, int(batch_size)))
|
||||
if len(batch) == 0:
|
||||
break
|
||||
yield list(batch)
|
||||
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user