Adapt parser and NER for transformers (#5449)

* Draft layer for BILUO actions

* Fixes to biluo layer

* WIP on BILUO layer

* Add tests for BILUO layer

* Format

* Fix transitions

* Update test

* Link in the simple_ner

* Update BILUO tagger

* Update __init__

* Import simple_ner

* Update test

* Import

* Add files

* Add config

* Fix label passing for BILUO and tagger

* Fix label handling for simple_ner component

* Update simple NER test

* Update config

* Hack train script

* Update BILUO layer

* Fix SimpleNER component

* Update train_from_config

* Add biluo_to_iob helper

* Add IOB layer

* Add IOBTagger model

* Update biluo layer

* Update SimpleNER tagger

* Update BILUO

* Read random seed in train-from-config

* Update use of normal_init

* Fix normalization of gradient in SimpleNER

* Update IOBTagger

* Remove print

* Tweak masking in BILUO

* Add dropout in SimpleNER

* Update thinc

* Tidy up simple_ner

* Fix biluo model

* Unhack train-from-config

* Update setup.cfg and requirements

* Add tb_framework.py for parser model

* Try to avoid memory leak in BILUO

* Move ParserModel into spacy.ml, avoid need for subclass.

* Use updated parser model

* Remove incorrect call to model.initialize in PrecomputableAffine

* Update parser model

* Avoid divide by zero in tagger

* Add extra dropout layer in tagger

* Refine minibatch_by_words function to avoid OOM

* Fix parser model after refactor

* Try to avoid div-by-zero in SimpleNER

* Fix infinite loop in minibatch_by_words

* Use SequenceCategoricalCrossentropy in Tagger

* Fix parser model when a hidden layer is used

* Remove extra dropout from tagger

* Add extra nan check in tagger

* Fix thinc version

* Update tests and imports

* Fix test

* Update test

* Update tests

* Fix tests

* Fix test

Co-authored-by: Ines Montani <ines@ines.io>
Matthew Honnibal 2020-05-18 22:23:33 +02:00 committed by GitHub
parent 3100c97e69
commit 333b1a308b
29 changed files with 1180 additions and 247 deletions

View File

@ -4,12 +4,18 @@ limit = 0
dropout = 0.2
patience = 10000
eval_frequency = 200
scores = ["ents_f"]
scores = ["ents_p", "ents_r", "ents_f"]
score_weights = {"ents_f": 1}
orth_variant_level = 0.0
gold_preproc = true
max_length = 0
batch_size = 25
[training.batch_size]
@schedules = "compounding.v1"
start = 3000
stop = 3000
compound = 1.001
[optimizer]
@optimizers = "Adam.v1"
@ -21,45 +27,18 @@ beta2 = 0.999
lang = "en"
vectors = null
[nlp.pipeline.tok2vec]
factory = "tok2vec"
[nlp.pipeline.tok2vec.model]
@architectures = "spacy.Tok2Vec.v1"
[nlp.pipeline.tok2vec.model.extract]
@architectures = "spacy.Doc2Feats.v1"
columns = ["ID", "NORM", "PREFIX", "SUFFIX", "SHAPE", "ORTH"]
[nlp.pipeline.tok2vec.model.embed]
@architectures = "spacy.MultiHashEmbed.v1"
columns = ${nlp.pipeline.tok2vec.model.extract:columns}
width = 96
rows = 2000
use_subwords = true
pretrained_vectors = null
[nlp.pipeline.tok2vec.model.embed.mix]
@architectures = "spacy.LayerNormalizedMaxout.v1"
width = ${nlp.pipeline.tok2vec.model.embed:width}
maxout_pieces = 3
[nlp.pipeline.tok2vec.model.encode]
@architectures = "spacy.MaxoutWindowEncoder.v1"
width = ${nlp.pipeline.tok2vec.model.embed:width}
window_size = 1
maxout_pieces = 3
depth = 2
[nlp.pipeline.ner]
factory = "ner"
factory = "simple_ner"
[nlp.pipeline.ner.model]
@architectures = "spacy.TransitionBasedParser.v1"
nr_feature_tokens = 6
hidden_width = 64
maxout_pieces = 2
@architectures = "spacy.BiluoTagger.v1"
[nlp.pipeline.ner.model.tok2vec]
@architectures = "spacy.Tok2VecTensors.v1"
width = ${nlp.pipeline.tok2vec.model.embed:width}
@architectures = "spacy.HashEmbedCNN.v1"
width = 128
depth = 4
embed_size = 7000
maxout_pieces = 3
window_size = 1
subword_features = true
pretrained_vectors = null

View File

@ -42,26 +42,28 @@ def main(model=None, output_dir=None, n_iter=100):
# create the built-in pipeline components and add them to the pipeline
# nlp.create_pipe works for built-ins that are registered with spaCy
if "ner" not in nlp.pipe_names:
ner = nlp.create_pipe("ner")
if "simple_ner" not in nlp.pipe_names:
ner = nlp.create_pipe("simple_ner")
nlp.add_pipe(ner, last=True)
# otherwise, get it so we can add labels
else:
ner = nlp.get_pipe("ner")
ner = nlp.get_pipe("simple_ner")
# add labels
for _, annotations in TRAIN_DATA:
for ent in annotations.get("entities"):
print("Add label", ent[2])
ner.add_label(ent[2])
# get names of other pipes to disable them during training
pipe_exceptions = ["ner", "trf_wordpiecer", "trf_tok2vec"]
pipe_exceptions = ["simple_ner"]
other_pipes = [pipe for pipe in nlp.pipe_names if pipe not in pipe_exceptions]
with nlp.disable_pipes(*other_pipes): # only train NER
# reset and initialize the weights randomly but only if we're
# training a new model
if model is None:
nlp.begin_training()
print("Transitions", list(enumerate(nlp.get_pipe("simple_ner").get_tag_names())))
for itn in range(n_iter):
random.shuffle(TRAIN_DATA)
losses = {}
@ -70,7 +72,7 @@ def main(model=None, output_dir=None, n_iter=100):
for batch in batches:
nlp.update(
batch,
drop=0.5, # dropout - make it harder to memorise data
drop=0.0, # dropout - make it harder to memorise data
losses=losses,
)
print("Losses", losses)

View File

@ -8,6 +8,7 @@ from wasabi import msg
import thinc
import thinc.schedules
from thinc.api import Model
import random
from ..gold import GoldCorpus
from .. import util
@ -119,6 +120,7 @@ class ConfigSchema(BaseModel):
output_path=("Output directory to store model in", "option", "o", Path),
meta_path=("Optional path to meta.json to use as base.", "option", "m", Path),
raw_text=("Path to jsonl file with unlabelled text documents.", "option", "rt", Path),
use_gpu=("Use GPU", "option", "g", int),
# fmt: on
)
def train_from_config_cli(
@ -130,6 +132,7 @@ def train_from_config_cli(
raw_text=None,
debug=False,
verbose=False,
use_gpu=-1
):
"""
Train or update a spaCy model. Requires data to be formatted in spaCy's
@ -147,6 +150,12 @@ def train_from_config_cli(
if output_path is not None and not output_path.exists():
output_path.mkdir()
if use_gpu >= 0:
msg.info("Using GPU")
util.use_gpu(use_gpu)
else:
msg.info("Using CPU")
train_from_config(
config_path,
{"train": train_path, "dev": dev_path},
@ -161,13 +170,8 @@ def train_from_config(
):
msg.info(f"Loading config from: {config_path}")
config = util.load_config(config_path, create_objects=False)
util.fix_random_seed(config["training"]["seed"])
nlp_config = config["nlp"]
use_gpu = config["training"]["use_gpu"]
if use_gpu >= 0:
msg.info("Using GPU")
util.use_gpu(use_gpu)
else:
msg.info("Using CPU")
config = util.load_config(config_path, create_objects=True)
msg.info("Creating nlp from config")
nlp = util.load_model_from_config(nlp_config)
@ -177,7 +181,7 @@ def train_from_config(
msg.info("Loading training corpus")
corpus = GoldCorpus(data_paths["train"], data_paths["dev"], limit=limit)
msg.info("Initializing the nlp pipeline")
nlp.begin_training(lambda: corpus.train_examples, device=use_gpu)
nlp.begin_training(lambda: corpus.train_examples)
train_batches = create_train_batches(nlp, corpus, training)
evaluate = create_evaluation_callback(nlp, optimizer, corpus, training)
@ -192,6 +196,7 @@ def train_from_config(
training["dropout"],
training["patience"],
training["eval_frequency"],
training["accumulate_gradient"]
)
msg.info(f"Training. Initial learn rate: {optimizer.learn_rate}")
@ -220,43 +225,50 @@ def train_from_config(
def create_train_batches(nlp, corpus, cfg):
while True:
train_examples = corpus.train_dataset(
train_examples = list(corpus.train_dataset(
nlp,
noise_level=0.0,
orth_variant_level=cfg["orth_variant_level"],
gold_preproc=cfg["gold_preproc"],
max_length=cfg["max_length"],
ignore_misaligned=True,
)
for batch in util.minibatch_by_words(train_examples, size=cfg["batch_size"]):
))
random.shuffle(train_examples)
batches = util.minibatch_by_words(train_examples, size=cfg["batch_size"])
for batch in batches:
yield batch
def create_evaluation_callback(nlp, optimizer, corpus, cfg):
def evaluate():
with nlp.use_params(optimizer.averages):
dev_examples = list(
corpus.dev_dataset(
nlp, gold_preproc=cfg["gold_preproc"], ignore_misaligned=True
)
dev_examples = list(
corpus.dev_dataset(
nlp, gold_preproc=cfg["gold_preproc"], ignore_misaligned=True
)
n_words = sum(len(ex.doc) for ex in dev_examples)
start_time = timer()
scorer = nlp.evaluate(dev_examples)
end_time = timer()
wps = n_words / (end_time - start_time)
scores = scorer.scores
# Calculate a weighted sum based on score_weights for the main score
weights = cfg["score_weights"]
weighted_score = sum(scores[s] * weights.get(s, 0.0) for s in weights)
scores["speed"] = wps
)
n_words = sum(len(ex.doc) for ex in dev_examples)
start_time = timer()
if optimizer.averages:
with nlp.use_params(optimizer.averages):
scorer = nlp.evaluate(dev_examples, batch_size=32)
else:
scorer = nlp.evaluate(dev_examples, batch_size=32)
end_time = timer()
wps = n_words / (end_time - start_time)
scores = scorer.scores
# Calculate a weighted sum based on score_weights for the main score
weights = cfg["score_weights"]
weighted_score = sum(scores[s] * weights.get(s, 0.0) for s in weights)
scores["speed"] = wps
return weighted_score, scores
return evaluate
def train_while_improving(
nlp, optimizer, train_data, evaluate, dropout, patience, eval_frequency
nlp, optimizer, train_data, evaluate, dropout, patience, eval_frequency,
accumulate_gradient
):
"""Train until an evaluation stops improving. Works as a generator,
with each iteration yielding a tuple `(batch, info, is_best_checkpoint)`,
@ -303,7 +315,7 @@ def train_while_improving(
losses = {}
for step, batch in enumerate(train_data):
dropout = next(dropouts)
for subbatch in subdivide_batch(batch):
for subbatch in subdivide_batch(batch, accumulate_gradient):
nlp.update(subbatch, drop=dropout, losses=losses, sgd=False)
for name, proc in nlp.pipeline:
if hasattr(proc, "model"):
@ -332,8 +344,19 @@ def train_while_improving(
break
def subdivide_batch(batch):
return [batch]
def subdivide_batch(batch, accumulate_gradient):
batch = list(batch)
batch.sort(key=lambda eg: len(eg.doc))
sub_len = len(batch) // accumulate_gradient
start = 0
for i in range(accumulate_gradient):
subbatch = batch[start : start + sub_len]
if subbatch:
yield subbatch
start += len(subbatch)
subbatch = batch[start : ]
if subbatch:
yield subbatch
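subdivide_batch is what makes the new accumulate_gradient setting work: each batch is sorted by doc length, split into accumulate_gradient sub-batches, and nlp.update is called on each with sgd=False, so the gradients accumulate before the models are updated once per batch. A toy trace of the splitting logic, with integers standing in for Example objects:

batch = list(range(10))       # pretend these are 10 Examples
accumulate_gradient = 3
sub_len = len(batch) // accumulate_gradient
subbatches = [batch[i * sub_len:(i + 1) * sub_len] for i in range(accumulate_gradient)]
if batch[accumulate_gradient * sub_len:]:
    subbatches.append(batch[accumulate_gradient * sub_len:])
assert [len(s) for s in subbatches] == [3, 3, 3, 1]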
def setup_printer(training, nlp):

View File

@ -608,6 +608,14 @@ def iob_to_biluo(tags):
return out
def biluo_to_iob(tags):
out = []
for tag in tags:
tag = tag.replace("U-", "B-", 1).replace("L-", "I-", 1)
out.append(tag)
return out
def _consume_os(tags):
while tags and tags[0] == "O":
yield tags.pop(0)
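The new biluo_to_iob helper is the inverse-direction counterpart of iob_to_biluo: it only rewrites U- to B- and L- to I-, leaving B-, I- and O tags untouched. A quick illustration of the mapping:

from spacy.gold import biluo_to_iob

tags = ["O", "B-PER", "L-PER", "U-LOC", "O"]
assert biluo_to_iob(tags) == ["O", "B-PER", "I-PER", "B-LOC", "O"]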

View File

@ -195,6 +195,7 @@ class Language(object):
default_senter_config,
default_tensorizer_config,
default_tok2vec_config,
default_simple_ner_config
)
self.defaults = {
@ -205,6 +206,7 @@ class Language(object):
"entity_linker": default_nel_config(),
"morphologizer": default_morphologizer_config(),
"senter": default_senter_config(),
"simple_ner": default_simple_ner_config(),
"tensorizer": default_tensorizer_config(),
"tok2vec": default_tok2vec_config(),
}

109
spacy/ml/_biluo.py Normal file
View File

@ -0,0 +1,109 @@
"""Thinc layer to do simpler transition-based parsing, NER, etc."""
from typing import List, Tuple, Dict, Optional
import numpy
from thinc.api import Ops, Model, with_array, softmax_activation, padded2list
from thinc.api import to_numpy
from thinc.types import Padded, Ints1d, Ints3d, Floats2d, Floats3d
from ..tokens import Doc
def BILUO() -> Model[Padded, Padded]:
return Model(
"biluo",
forward,
init=init,
dims={"nO": None},
attrs={"get_num_actions": get_num_actions}
)
def init(model, X: Optional[Padded]=None, Y: Optional[Padded]=None):
if X is not None and Y is not None:
if X.data.shape != Y.data.shape:
# TODO: Fix error
raise ValueError("Mismatched shapes (TODO: Fix message)")
model.set_dim("nO", X.data.shape[2])
elif X is not None:
model.set_dim("nO", X.data.shape[2])
elif Y is not None:
model.set_dim("nO", Y.data.shape[2])
elif model.get_dim("nO") is None:
raise ValueError("Dimension unset for BILUO: nO")
def forward(model: Model[Padded, Padded], Xp: Padded, is_train: bool):
n_labels = (model.get_dim("nO") - 1) // 4
n_tokens, n_docs, n_actions = Xp.data.shape
# At each timestep, we make a validity mask of shape (n_docs, n_actions)
# to indicate which actions are valid next for each sequence. To construct
# the mask, we have a state of shape (2, n_actions) and a validity table of
# shape (2, n_actions+1, n_actions). The first dimension of the state indicates
# whether it's the last token, the second dimension indicates the previous
# action, plus a special 'null action' for the first entry.
valid_transitions = model.ops.asarray(_get_transition_table(n_labels))
prev_actions = model.ops.alloc1i(n_docs)
# Initialize as though prev action was O
prev_actions.fill(n_actions - 1)
Y = model.ops.alloc3f(*Xp.data.shape)
masks = model.ops.alloc3f(*Y.shape)
max_value = Xp.data.max()
for t in range(Xp.data.shape[0]):
is_last = (Xp.lengths < (t+2)).astype("i")
masks[t] = valid_transitions[is_last, prev_actions]
# Don't train the out-of-bounds sequences.
masks[t, Xp.size_at_t[t]:] = 0
# Valid actions get 0*10e8, invalid get large negative value
Y[t] = Xp.data[t] + ((masks[t]-1) * max_value * 10)
prev_actions = Y[t].argmax(axis=-1)
def backprop_biluo(dY: Padded) -> Padded:
dY.data *= masks
return dY
return Padded(Y, Xp.size_at_t, Xp.lengths, Xp.indices), backprop_biluo
def get_num_actions(n_labels: int) -> int:
# One BEGIN action per label
# One IN action per label
# One LAST action per label
# One UNIT action per label
# One OUT action
return n_labels + n_labels + n_labels + n_labels + 1
def _get_transition_table(
n_labels: int, *, _cache: Dict[int, Floats3d] = {}
) -> Floats3d:
n_actions = get_num_actions(n_labels)
if n_actions in _cache:
return _cache[n_actions]
table = numpy.zeros((2, n_actions, n_actions), dtype="f")
B_start, B_end = (0, n_labels)
I_start, I_end = (B_end, B_end + n_labels)
L_start, L_end = (I_end, I_end + n_labels)
U_start, U_end = (L_end, L_end + n_labels)
# Using ranges allows us to set specific cells, which is necessary to express
# that only actions of the same label are valid continuations.
B_range = numpy.arange(B_start, B_end)
I_range = numpy.arange(I_start, I_end)
L_range = numpy.arange(L_start, L_end)
O_action = U_end
# If this is the last token and the previous action was B or I, only L
# of that label is valid
table[1, B_range, L_range] = 1
table[1, I_range, L_range] = 1
# If this isn't the last token and the previous action was B or I, only I or
# L of that label are valid.
table[0, B_range, I_range] = 1
table[0, B_range, L_range] = 1
table[0, I_range, I_range] = 1
table[0, I_range, L_range] = 1
# If this isn't the last token and the previous was L, U or O, B is valid
table[0, L_start:, :B_end] = 1
# Regardless of whether this is the last token, if the previous action was
# {L, U, O}, U and O are valid.
table[:, L_start:, U_start:] = 1
_cache[n_actions] = table
return table
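For a quick sense of the dimensions involved: the BILUO layer expects one B/I/L/U action per label plus a single O action, and the transition table has one plane for "not last token" and one for "last token". A small sanity check along the lines of the tests added below (a sketch, assuming the new spacy.ml._biluo module is importable):

from spacy.ml._biluo import BILUO, _get_transition_table

n_labels = 3  # e.g. PER, LOC, ORG
model = BILUO()
assert model.attrs["get_num_actions"](n_labels) == 4 * n_labels + 1  # 13 actions
table = _get_transition_table(n_labels)
# axis 0: is-last-token flag; axes 1 and 2: previous action -> valid next action
assert table.shape == (2, 13, 13)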

92
spacy/ml/_iob.py Normal file
View File

@ -0,0 +1,92 @@
"""Thinc layer to do simpler transition-based parsing, NER, etc."""
from typing import List, Tuple, Dict, Optional
from thinc.api import Ops, Model, with_array, softmax_activation, padded2list
from thinc.types import Padded, Ints1d, Ints3d, Floats2d, Floats3d
from ..tokens import Doc
def IOB() -> Model[Padded, Padded]:
return Model(
"biluo",
forward,
init=init,
dims={"nO": None},
attrs={"get_num_actions": get_num_actions}
)
def init(model, X: Optional[Padded]=None, Y: Optional[Padded]=None):
if X is not None and Y is not None:
if X.data.shape != Y.data.shape:
# TODO: Fix error
raise ValueError("Mismatched shapes (TODO: Fix message)")
model.set_dim("nO", X.data.shape[2])
elif X is not None:
model.set_dim("nO", X.data.shape[2])
elif Y is not None:
model.set_dim("nO", Y.data.shape[2])
elif model.get_dim("nO") is None:
raise ValueError("Dimension unset for BILUO: nO")
def forward(model: Model[Padded, Padded], Xp: Padded, is_train: bool):
n_labels = (model.get_dim("nO") - 1) // 2
n_tokens, n_docs, n_actions = Xp.data.shape
# At each timestep, we make a validity mask of shape (n_docs, n_actions)
# to indicate which actions are valid next for each sequence. To construct
# the mask, we have a state of shape (2, n_actions) and a validity table of
# shape (2, n_actions+1, n_actions). The first dimension of the state indicates
# whether it's the last token, the second dimension indicates the previous
# action, plus a special 'null action' for the first entry.
valid_transitions = _get_transition_table(model.ops, n_labels)
prev_actions = model.ops.alloc1i(n_docs)
# Initialize as though prev action was O
prev_actions.fill(n_actions - 1)
Y = model.ops.alloc3f(*Xp.data.shape)
masks = model.ops.alloc3f(*Y.shape)
for t in range(Xp.data.shape[0]):
masks[t] = valid_transitions[prev_actions]
# Don't train the out-of-bounds sequences.
masks[t, Xp.size_at_t[t]:] = 0
# Valid actions get 0*10e8, invalid get -1*10e8
Y[t] = Xp.data[t] + ((masks[t]-1) * 10e8)
prev_actions = Y[t].argmax(axis=-1)
def backprop_biluo(dY: Padded) -> Padded:
# Masking the gradient seems to do poorly here. But why?
#dY.data *= masks
return dY
return Padded(Y, Xp.size_at_t, Xp.lengths, Xp.indices), backprop_biluo
def get_num_actions(n_labels: int) -> int:
# One BEGIN action per label
# One IN action per label
# One LAST action per label
# One UNIT action per label
# One OUT action
return n_labels * 2 + 1
def _get_transition_table(
ops: Ops, n_labels: int, _cache: Dict[int, Floats3d] = {}
) -> Floats3d:
n_actions = get_num_actions(n_labels)
if n_actions in _cache:
return ops.asarray(_cache[n_actions])
table = ops.alloc2f(n_actions, n_actions)
B_start, B_end = (0, n_labels)
I_start, I_end = (B_end, B_end + n_labels)
O_action = I_end
B_range = ops.xp.arange(B_start, B_end)
I_range = ops.xp.arange(I_start, I_end)
# B and O are always valid
table[:, B_start : B_end] = 1
table[:, O_action] = 1
# I can only follow a matching B
table[B_range, I_range] = 1
_cache[n_actions] = table
return table
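The IOB variant is simpler: two actions per label plus O, and the table marks B and O as valid after any action while I-x is only marked valid after B-x. A small check (a sketch, assuming the new spacy.ml._iob module; thinc's NumpyOps supplies the array ops):

from thinc.api import NumpyOps
from spacy.ml._iob import IOB, _get_transition_table

ops = NumpyOps()
n_labels = 2  # actions: B-PER, B-LOC, I-PER, I-LOC, O
model = IOB()
assert model.attrs["get_num_actions"](n_labels) == 2 * n_labels + 1
table = _get_transition_table(ops, n_labels)
# rows are the previous action, columns the candidate next action
assert table[4, 0] == 1  # O -> B-PER is valid
assert table[4, 2] == 0  # O -> I-PER is masked out
assert table[0, 2] == 1  # B-PER -> I-PER is valid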

View File

@ -9,7 +9,6 @@ def PrecomputableAffine(nO, nI, nF, nP):
dims={"nO": nO, "nI": nI, "nF": nF, "nP": nP},
params={"W": None, "b": None, "pad": None},
)
model.initialize()
return model
@ -110,8 +109,7 @@ def init(model, X=None, Y=None):
pad = model.ops.alloc4f(1, nF, nO, nP)
ops = model.ops
scale = float(ops.xp.sqrt(1.0 / (nF * nI)))
W = normal_init(ops, W.shape, mean=scale)
W = normal_init(ops, W.shape, mean=float(ops.xp.sqrt(1.0 / nF * nI)))
model.set_param("W", W)
model.set_param("b", b)
model.set_param("pad", pad)

View File

@ -1,5 +1,6 @@
from .entity_linker import * # noqa
from .parser import * # noqa
from .simple_ner import *
from .tagger import * # noqa
from .tensorizer import * # noqa
from .textcat import * # noqa

View File

@ -91,3 +91,13 @@ def default_tok2vec_config():
def default_tok2vec():
loc = Path(__file__).parent / "tok2vec_defaults.cfg"
return util.load_config(loc, create_objects=True)["model"]
def default_simple_ner_config():
loc = Path(__file__).parent / "simple_ner_defaults.cfg"
return util.load_config(loc, create_objects=False)
def default_simple_ner():
loc = Path(__file__).parent / "simple_ner_defaults.cfg"
return util.load_config(loc, create_objects=True)["model"]

View File

@ -0,0 +1,12 @@
[model]
@architectures = "spacy.BiluoTagger.v1"
[model.tok2vec]
@architectures = "spacy.HashEmbedCNN.v1"
pretrained_vectors = null
width = 128
depth = 4
embed_size = 7000
window_size = 1
maxout_pieces = 3
subword_features = true

View File

@ -1,9 +1,9 @@
from pydantic import StrictInt
from thinc.api import Model, chain, list2array, Linear, zero_init, use_ops
from thinc.api import Model, chain, list2array, Linear, zero_init, use_ops, with_array
from ...util import registry
from .._precomputable_affine import PrecomputableAffine
from ...syntax._parser_model import ParserModel
from ..tb_framework import TransitionModel
@registry.architectures.register("spacy.TransitionBasedParser.v1")
@ -12,21 +12,27 @@ def build_tb_parser_model(
nr_feature_tokens: StrictInt,
hidden_width: StrictInt,
maxout_pieces: StrictInt,
use_upper=True,
nO=None,
):
token_vector_width = tok2vec.get_dim("nO")
tok2vec = chain(tok2vec, list2array())
tok2vec.set_dim("nO", token_vector_width)
tok2vec = chain(
tok2vec,
with_array(Linear(hidden_width, token_vector_width)),
list2array(),
)
tok2vec.set_dim("nO", hidden_width)
lower = PrecomputableAffine(
nO=hidden_width,
nO=hidden_width if use_upper else nO,
nF=nr_feature_tokens,
nI=tok2vec.get_dim("nO"),
nP=maxout_pieces,
nP=maxout_pieces
)
lower.set_dim("nP", maxout_pieces)
with use_ops("numpy"):
# Initialize weights at zero, as it's a classification layer.
upper = Linear(nO=nO, init_W=zero_init)
model = ParserModel(tok2vec, lower, upper)
return model
if use_upper:
with use_ops("numpy"):
# Initialize weights at zero, as it's a classification layer.
upper = Linear(nO=nO, init_W=zero_init)
else:
upper = None
return TransitionModel(tok2vec, lower, upper)

View File

@ -0,0 +1,82 @@
import functools
from typing import List, Tuple, Dict, Optional
from thinc.api import Ops, Model, Linear, Softmax, with_array, softmax_activation, padded2list
from thinc.api import chain, list2padded, configure_normal_init
from thinc.api import Dropout
from thinc.types import Padded, Ints1d, Ints3d, Floats2d, Floats3d
from ...tokens import Doc
from .._biluo import BILUO
from .._iob import IOB
from ...util import registry
@registry.architectures.register("spacy.BiluoTagger.v1")
def BiluoTagger(tok2vec: Model[List[Doc], List[Floats2d]]) -> Model[List[Doc], List[Floats2d]]:
biluo = BILUO()
linear = Linear(
nO=None,
nI=tok2vec.get_dim("nO"),
init_W=configure_normal_init(mean=0.02)
)
model = chain(
tok2vec,
list2padded(),
with_array(chain(Dropout(0.1), linear)),
biluo,
with_array(softmax_activation()),
padded2list()
)
return Model(
"biluo-tagger",
forward,
init=init,
layers=[model, linear],
refs={"tok2vec": tok2vec, "linear": linear, "biluo": biluo},
dims={"nO": None},
attrs={"get_num_actions": biluo.attrs["get_num_actions"]}
)
@registry.architectures.register("spacy.IOBTagger.v1")
def IOBTagger(tok2vec: Model[List[Doc], List[Floats2d]]) -> Model[List[Doc], List[Floats2d]]:
biluo = IOB()
linear = Linear(nO=None, nI=tok2vec.get_dim("nO"))
model = chain(
tok2vec,
list2padded(),
with_array(linear),
biluo,
with_array(softmax_activation()),
padded2list()
)
return Model(
"iob-tagger",
forward,
init=init,
layers=[model],
refs={"tok2vec": tok2vec, "linear": linear, "biluo": biluo},
dims={"nO": None},
attrs={"get_num_actions": biluo.attrs["get_num_actions"]}
)
def init(model: Model[List[Doc], List[Floats2d]], X=None, Y=None) -> None:
if model.get_dim("nO") is None and Y:
model.set_dim("nO", Y[0].shape[1])
nO = model.get_dim("nO")
biluo = model.get_ref("biluo")
linear = model.get_ref("linear")
biluo.set_dim("nO", nO)
if linear.has_dim("nO") is None:
linear.set_dim("nO", nO)
model.layers[0].initialize(X=X, Y=Y)
def forward(model: Model, X: List[Doc], is_train: bool):
return model.layers[0](X, is_train)
__all__ = ["BiluoTagger"]

View File

@ -1,4 +1,5 @@
from thinc.api import zero_init, with_array, Softmax, chain, Model
from thinc.api import zero_init, with_array, Softmax, chain, Model, Dropout
from thinc.api import glorot_uniform_init
from ...util import registry
@ -11,6 +12,6 @@ def build_tagger_model(tok2vec, nO=None) -> Model:
softmax = with_array(output_layer)
model = chain(tok2vec, softmax)
model.set_ref("tok2vec", tok2vec)
model.set_ref("softmax", softmax)
model.set_ref("softmax", output_layer)
model.set_ref("output_layer", output_layer)
return model

86
spacy/ml/tb_framework.py Normal file
View File

@ -0,0 +1,86 @@
from thinc.api import Model, noop, use_ops, Linear
from ..syntax._parser_model import ParserStepModel
def TransitionModel(tok2vec, lower, upper, unseen_classes=set()):
"""Set up a stepwise transition-based model"""
if upper is None:
has_upper = False
upper = noop()
else:
has_upper = True
# don't define nO for this object, because we can't dynamically change it
return Model(
name="parser_model",
forward=forward,
dims={"nI": tok2vec.get_dim("nI") if tok2vec.has_dim("nI") else None},
layers=[tok2vec, lower, upper],
refs={"tok2vec": tok2vec, "lower": lower, "upper": upper},
init=init,
attrs={
"has_upper": has_upper,
"unseen_classes": set(unseen_classes),
"resize_output": resize_output
}
)
def forward(model, X, is_train):
step_model = ParserStepModel(
X,
model.layers,
unseen_classes=model.attrs["unseen_classes"],
train=is_train,
has_upper=model.attrs["has_upper"]
)
return step_model, step_model.finish_steps
def init(model, X=None, Y=None):
tok2vec = model.get_ref("tok2vec").initialize()
lower = model.get_ref("lower").initialize(X=X)
if model.attrs["has_upper"]:
statevecs = model.ops.alloc2f(2, lower.get_dim("nO"))
model.get_ref("upper").initialize(X=statevecs)
def resize_output(model, new_nO):
tok2vec = model.get_ref("tok2vec")
lower = model.get_ref("lower")
upper = model.get_ref("upper")
if not model.attrs["has_upper"]:
if lower.has_dim("nO") is None:
lower.set_dim("nO", new_nO)
return
elif upper.has_dim("nO") is None:
upper.set_dim("nO", new_nO)
return
elif new_nO == upper.get_dim("nO"):
return
smaller = upper
nI = None
if smaller.has_dim("nI"):
nI = smaller.get_dim("nI")
with use_ops('numpy'):
larger = Linear(nO=new_nO, nI=nI)
larger.init = smaller.init
# it could be that the model is not initialized yet, then skip this bit
if nI:
larger_W = larger.ops.alloc2f(new_nO, nI)
larger_b = larger.ops.alloc1f(new_nO)
smaller_W = smaller.get_param("W")
smaller_b = smaller.get_param("b")
# Weights are stored in (nr_out, nr_in) format, so we're basically
# just adding rows here.
if smaller.has_dim("nO"):
larger_W[:smaller.get_dim("nO")] = smaller_W
larger_b[:smaller.get_dim("nO")] = smaller_b
for i in range(smaller.get_dim("nO"), new_nO):
model.attrs["unseen_classes"].add(i)
larger.set_param("W", larger_W)
larger.set_param("b", larger_b)
model._layers[-1] = larger
model.set_ref("upper", larger)
return model
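With ParserModel gone, output resizing is exposed as a model attribute rather than a method, which is how the pipeline and the updated tests now call it. A minimal usage sketch (the class count 42 is arbitrary):

from spacy.ml.models.defaults import default_parser

model = default_parser()
model.attrs["resize_output"](model, 42)  # grow the upper layer to 42 classes
model.initialize()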

View File

@ -1,6 +1,7 @@
from .pipes import Tagger, DependencyParser, EntityRecognizer, EntityLinker
from .pipes import TextCategorizer, Tensorizer, Pipe, Sentencizer
from .pipes import SentenceRecognizer
from .simple_ner import SimpleNER
from .morphologizer import Morphologizer
from .entityruler import EntityRuler
from .tok2vec import Tok2Vec
@ -22,6 +23,7 @@ __all__ = [
"SentenceSegmenter",
"SentenceRecognizer",
"SimilarityHook",
"SimpleNER",
"merge_entities",
"merge_noun_chunks",
"merge_subtokens",

View File

@ -3,7 +3,7 @@ import numpy
import srsly
import random
from thinc.api import CosineDistance, to_categorical, get_array_module
from thinc.api import set_dropout_rate
from thinc.api import set_dropout_rate, SequenceCategoricalCrossentropy
import warnings
from ..tokens.doc cimport Doc
@ -464,6 +464,9 @@ class Tagger(Pipe):
return
set_dropout_rate(self.model, drop)
tag_scores, bp_tag_scores = self.model.begin_update([ex.doc for ex in examples])
for sc in tag_scores:
if self.model.ops.xp.isnan(sc.sum()):
raise ValueError("nan value in scores")
loss, d_tag_scores = self.get_loss(examples, tag_scores)
bp_tag_scores(d_tag_scores)
if sgd not in (None, False):
@ -497,29 +500,11 @@ class Tagger(Pipe):
losses[self.name] += (gradient**2).sum()
def get_loss(self, examples, scores):
scores = self.model.ops.flatten(scores)
tag_index = {tag: i for i, tag in enumerate(self.labels)}
cdef int idx = 0
correct = numpy.zeros((scores.shape[0],), dtype="i")
guesses = scores.argmax(axis=1)
known_labels = numpy.ones((scores.shape[0], 1), dtype="f")
for ex in examples:
gold = ex.gold
for tag in gold.tags:
if tag is None:
correct[idx] = guesses[idx]
elif tag in tag_index:
correct[idx] = tag_index[tag]
else:
correct[idx] = 0
known_labels[idx] = 0.
idx += 1
correct = self.model.ops.xp.array(correct, dtype="i")
d_scores = scores - to_categorical(correct, n_classes=scores.shape[1])
d_scores *= self.model.ops.asarray(known_labels)
loss = (d_scores**2).sum()
docs = [ex.doc for ex in examples]
d_scores = self.model.ops.unflatten(d_scores, [len(d) for d in docs])
loss_func = SequenceCategoricalCrossentropy(names=self.labels)
truths = [eg.gold.tags for eg in examples]
d_scores, loss = loss_func(scores, truths)
if self.model.ops.xp.isnan(loss):
raise ValueError("nan value when computing loss")
return float(loss), d_scores
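The Tagger now delegates its loss to thinc's SequenceCategoricalCrossentropy, which takes per-sequence score arrays plus string labels and returns the gradient alongside the loss. A small standalone sketch of that call shape, with made-up tag names and scores:

import numpy
from thinc.api import SequenceCategoricalCrossentropy

loss_func = SequenceCategoricalCrossentropy(names=["NOUN", "VERB"])
scores = [numpy.asarray([[0.9, 0.1], [0.2, 0.8]], dtype="f")]  # one doc, two tokens
truths = [["NOUN", "VERB"]]
d_scores, loss = loss_func(scores, truths)
assert d_scores[0].shape == (2, 2)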
def begin_training(self, get_examples=lambda: [], pipeline=None, sgd=None,

View File

@ -0,0 +1,149 @@
from typing import List
from thinc.types import Floats2d
from thinc.api import SequenceCategoricalCrossentropy, set_dropout_rate
from thinc.util import to_numpy
from ..gold import Example, spans_from_biluo_tags, iob_to_biluo, biluo_to_iob
from ..tokens import Doc
from ..language import component
from ..util import link_vectors_to_models
from .pipes import Pipe
@component("simple_ner", assigns=["doc.ents"])
class SimpleNER(Pipe):
"""Named entity recognition with a tagging model. The model should include
validity constraints to ensure that only valid tag sequences are returned."""
def __init__(self, vocab, model):
self.vocab = vocab
self.model = model
self.cfg = {"labels": []}
self.loss_func = SequenceCategoricalCrossentropy(
names=self.get_tag_names(),
normalize=True,
missing_value=None
)
assert self.model is not None
@property
def labels(self):
return self.cfg["labels"]
@property
def is_biluo(self):
return self.model.name.startswith("biluo")
def add_label(self, label):
if label not in self.cfg["labels"]:
self.cfg["labels"].append(label)
def get_tag_names(self):
if self.is_biluo:
return (
[f"B-{label}" for label in self.labels] +
[f"I-{label}" for label in self.labels] +
[f"L-{label}" for label in self.labels] +
[f"U-{label}" for label in self.labels] +
["O"]
)
else:
return (
[f"B-{label}" for label in self.labels] +
[f"I-{label}" for label in self.labels] +
["O"]
)
def predict(self, docs: List[Doc]) -> List[Floats2d]:
scores = self.model.predict(docs)
return scores
def set_annotations(self, docs: List[Doc], scores: List[Floats2d], tensors=None):
"""Set entities on a batch of documents from a batch of scores."""
tag_names = self.get_tag_names()
for i, doc in enumerate(docs):
actions = to_numpy(scores[i].argmax(axis=1))
tags = [tag_names[actions[j]] for j in range(len(doc))]
if not self.is_biluo:
tags = iob_to_biluo(tags)
doc.ents = spans_from_biluo_tags(doc, tags)
def update(self, examples, set_annotations=False, drop=0.0, sgd=None, losses=None):
if not any(_has_ner(eg) for eg in examples):
return 0
examples = Example.to_example_objects(examples)
docs = [ex.doc for ex in examples]
set_dropout_rate(self.model, drop)
scores, bp_scores = self.model.begin_update(docs)
loss, d_scores = self.get_loss(examples, scores)
bp_scores(d_scores)
if set_annotations:
self.set_annotations(docs, scores)
if sgd is not None:
self.model.finish_update(sgd)
if losses is not None:
losses.setdefault("ner", 0.0)
losses["ner"] += loss
return loss
def get_loss(self, examples, scores):
loss = 0
d_scores = []
truths = []
for eg in examples:
gold_tags = [(tag if tag != "-" else None) for tag in eg.gold.ner]
if not self.is_biluo:
gold_tags = biluo_to_iob(gold_tags)
truths.append(gold_tags)
for i in range(len(scores)):
if len(scores[i]) != len(truths[i]):
raise ValueError(
f"Mismatched output and gold sizes.\n"
f"Output: {len(scores[i])}, gold: {len(truths[i])}."
f"Input: {len(examples[i].doc)}"
)
d_scores, loss = self.loss_func(scores, truths)
return loss, d_scores
def begin_training(self, get_examples, pipeline=None, sgd=None, **kwargs):
self.cfg.update(kwargs)
if not hasattr(get_examples, '__call__'):
gold_tuples = get_examples
get_examples = lambda: gold_tuples
labels = _get_labels(get_examples())
for label in _get_labels(get_examples()):
self.add_label(label)
labels = self.labels
n_actions = self.model.attrs["get_num_actions"](len(labels))
self.model.set_dim("nO", n_actions)
self.model.initialize()
if pipeline is not None:
self.init_multitask_objectives(get_examples, pipeline, sgd=sgd, **self.cfg)
link_vectors_to_models(self.vocab)
self.loss_func = SequenceCategoricalCrossentropy(
names=self.get_tag_names(),
normalize=True,
missing_value=None
)
return sgd
def init_multitask_objectives(self, *args, **kwargs):
pass
def _has_ner(eg):
for ner_tag in eg.gold.ner:
if ner_tag != "-" and ner_tag != None:
return True
else:
return False
def _get_labels(examples):
labels = set()
for eg in examples:
for ner_tag in eg.token_annotation.entities:
if ner_tag != 'O' and ner_tag != '-':
_, label = ner_tag.split('-', 1)
labels.add(label)
return list(sorted(labels))
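How the tag inventory falls out of the labels: with the default BiluoTagger model, get_tag_names yields four tags per label plus O, in B/I/L/U blocks, and begin_training sets the model's output dimension to that count. A sketch, assuming the default simple_ner model config added in this commit:

from spacy.vocab import Vocab
from spacy.pipeline.simple_ner import SimpleNER
from spacy.ml.models.defaults import default_simple_ner

ner = SimpleNER(Vocab(), default_simple_ner())
ner.add_label("PERSON")
ner.add_label("LOC")
assert ner.get_tag_names() == [
    "B-PERSON", "B-LOC", "I-PERSON", "I-LOC",
    "L-PERSON", "L-LOC", "U-PERSON", "U-LOC", "O",
]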

View File

@ -12,7 +12,7 @@ cimport blis.cy
import numpy
import numpy.random
from thinc.api import Linear, Model, CupyOps, NumpyOps, use_ops
from thinc.api import Linear, Model, CupyOps, NumpyOps, use_ops, noop
from ..typedefs cimport weight_t, class_t, hash_t
from ..tokens.doc cimport Doc
@ -219,112 +219,27 @@ cdef int arg_max_if_valid(const weight_t* scores, const int* is_valid, int n) no
return best
class ParserModel(Model):
def __init__(self, tok2vec, lower_model, upper_model, unseen_classes=None):
# don't define nO for this object, because we can't dynamically change it
Model.__init__(self, name="parser_model", forward=forward, dims={"nI": None})
if tok2vec.has_dim("nI"):
self.set_dim("nI", tok2vec.get_dim("nI"))
self._layers = [tok2vec, lower_model]
if upper_model is not None:
self._layers.append(upper_model)
self.unseen_classes = set()
if unseen_classes:
for class_ in unseen_classes:
self.unseen_classes.add(class_)
self.set_ref("tok2vec", tok2vec)
def predict(self, docs):
step_model = ParserStepModel(docs, self._layers,
unseen_classes=self.unseen_classes, train=False)
return step_model
def resize_output(self, new_nO):
if len(self._layers) == 2:
return
if self.upper.has_dim("nO") and (new_nO == self.upper.get_dim("nO")):
return
smaller = self.upper
nI = None
if smaller.has_dim("nI"):
nI = smaller.get_dim("nI")
with use_ops('numpy'):
larger = Linear(nO=new_nO, nI=nI)
larger.init = smaller.init
# it could be that the model is not initialized yet, then skip this bit
if nI:
larger_W = larger.ops.alloc2f(new_nO, nI)
larger_b = larger.ops.alloc1f(new_nO)
smaller_W = smaller.get_param("W")
smaller_b = smaller.get_param("b")
# Weights are stored in (nr_out, nr_in) format, so we're basically
# just adding rows here.
if smaller.has_dim("nO"):
larger_W[:smaller.get_dim("nO")] = smaller_W
larger_b[:smaller.get_dim("nO")] = smaller_b
for i in range(smaller.get_dim("nO"), new_nO):
self.unseen_classes.add(i)
larger.set_param("W", larger_W)
larger.set_param("b", larger_b)
self._layers[-1] = larger
def initialize(self, X=None, Y=None):
self.tok2vec.initialize()
self.lower.initialize(X=X, Y=Y)
if self.upper is not None:
# In case we need to trigger the callbacks
statevecs = self.ops.alloc((2, self.lower.get_dim("nO")))
self.upper.initialize(X=statevecs)
def finish_update(self, optimizer):
self.tok2vec.finish_update(optimizer)
self.lower.finish_update(optimizer)
if self.upper is not None:
self.upper.finish_update(optimizer)
@property
def tok2vec(self):
return self._layers[0]
@property
def lower(self):
return self._layers[1]
@property
def upper(self):
return self._layers[2]
def forward(model:ParserModel, X, is_train):
step_model = ParserStepModel(X, model._layers, unseen_classes=model.unseen_classes,
train=is_train)
return step_model, step_model.finish_steps
class ParserStepModel(Model):
def __init__(self, docs, layers, unseen_classes=None, train=True):
def __init__(self, docs, layers, *, has_upper, unseen_classes=None, train=True):
Model.__init__(self, name="parser_step_model", forward=step_forward)
self.attrs["has_upper"] = has_upper
self.tokvecs, self.bp_tokvecs = layers[0](docs, is_train=train)
if layers[1].get_dim("nP") >= 2:
activation = "maxout"
elif len(layers) == 2:
elif has_upper:
activation = None
else:
activation = "relu"
self.state2vec = precompute_hiddens(len(docs), self.tokvecs, layers[1],
activation=activation, train=train)
if len(layers) == 3:
if has_upper:
self.vec2scores = layers[-1]
else:
self.vec2scores = None
self.cuda_stream = util.get_cuda_stream(non_blocking=True)
self.backprops = []
if self.vec2scores is None:
self._class_mask = numpy.zeros((self.state2vec.nO,), dtype='f')
else:
self._class_mask = numpy.zeros((self.vec2scores.get_dim("nO"),), dtype='f')
self._class_mask = numpy.zeros((self.nO,), dtype='f')
self._class_mask.fill(1)
if unseen_classes is not None:
for class_ in unseen_classes:
@ -332,7 +247,10 @@ class ParserStepModel(Model):
@property
def nO(self):
return self.state2vec.nO
if self.attrs["has_upper"]:
return self.vec2scores.get_dim("nO")
else:
return self.state2vec.get_dim("nO")
def class_is_unseen(self, class_):
return self._class_mask[class_]
@ -378,7 +296,7 @@ class ParserStepModel(Model):
def step_forward(model: ParserStepModel, states, is_train):
token_ids = model.get_token_ids(states)
vector, get_d_tokvecs = model.state2vec(token_ids, is_train)
if model.vec2scores is not None:
if model.attrs["has_upper"]:
scores, get_d_vector = model.vec2scores(vector, is_train)
else:
scores = NumpyOps().asarray(vector)

View File

@ -36,7 +36,6 @@ from ..util import link_vectors_to_models, create_default_optimizer, registry
from ..compat import copy_array
from ..errors import Errors, Warnings
from .. import util
from ._parser_model import ParserModel
from . import _beam_utils
from . import nonproj
@ -69,7 +68,8 @@ cdef class Parser:
cfg.setdefault('beam_width', 1)
cfg.setdefault('beam_update_prob', 1.0) # or 0.5 (both defaults were previously used)
self.model = model
self.set_output(self.moves.n_moves)
if self.moves.n_moves != 0:
self.set_output(self.moves.n_moves)
self.cfg = cfg
self._multitasks = []
self._rehearsal_model = None
@ -105,7 +105,7 @@ cdef class Parser:
@property
def tok2vec(self):
'''Return the embedding and convolutional layer of the model.'''
return self.model.tok2vec
return self.model.get_ref("tok2vec")
@property
def postprocesses(self):
@ -122,9 +122,11 @@ cdef class Parser:
self._resize()
def _resize(self):
self.model.resize_output(self.moves.n_moves)
self.model.attrs["resize_output"](self.model, self.moves.n_moves)
if self._rehearsal_model not in (True, False, None):
self._rehearsal_model.resize_output(self.moves.n_moves)
self._rehearsal_model.attrs["resize_output"](
self._rehearsal_model, self.moves.n_moves
)
def add_multitask_objective(self, target):
# Defined in subclasses, to avoid circular import
@ -216,7 +218,6 @@ cdef class Parser:
# expand our model output.
self._resize()
model = self.model.predict(docs)
W_param = model.vec2scores.get_param("W")
weights = get_c_weights(model)
for state in batch:
if not state.is_final():
@ -237,7 +238,7 @@ cdef class Parser:
# if labels are missing. We therefore have to check whether we need to
# expand our model output.
self._resize()
cdef int nr_feature = self.model.lower.get_dim("nF")
cdef int nr_feature = self.model.get_ref("lower").get_dim("nF")
model = self.model.predict(docs)
token_ids = numpy.zeros((len(docs) * beam_width, nr_feature),
dtype='i', order='C')
@ -370,13 +371,16 @@ cdef class Parser:
beam_density=self.cfg.get('beam_density', 0.001))
set_dropout_rate(self.model, drop)
# Chop sequences into lengths of this many transitions, to make the
# batch uniform length.
cut_gold = numpy.random.choice(range(20, 100))
states, golds, max_steps = self._init_gold_batch(examples, max_length=cut_gold)
cut_gold = True
if cut_gold:
# Chop sequences into lengths of this many transitions, to make the
# batch uniform length.
cut_gold = numpy.random.choice(range(20, 100))
states, golds, max_steps = self._init_gold_batch(examples, max_length=cut_gold)
else:
states, golds, max_steps = self._init_gold_batch_no_cut(examples)
states_golds = [(s, g) for (s, g) in zip(states, golds)
if not s.is_final() and g is not None]
# Prepare the stepwise model, and get the callback for finishing the batch
model, backprop_tok2vec = self.model.begin_update([ex.doc for ex in examples])
all_states = list(states)
@ -456,9 +460,17 @@ cdef class Parser:
set_dropout_rate(self.model, drop)
model, backprop_tok2vec = self.model.begin_update(docs)
states_d_scores, backprops, beams = _beam_utils.update_beam(
self.moves, self.model.lower.get_dim("nF"), 10000, states, golds,
model.state2vec, model.vec2scores, width, losses=losses,
beam_density=beam_density)
self.moves,
self.model.get_ref("lower").get_dim("nF"),
10000,
states,
golds,
model.state2vec,
model.vec2scores,
width,
losses=losses,
beam_density=beam_density
)
for i, d_scores in enumerate(states_d_scores):
losses[self.name] += (d_scores**2).mean()
ids, bp_vectors, bp_scores = backprops[i]
@ -497,6 +509,24 @@ cdef class Parser:
queue.extend(node._layers)
return gradients
def _init_gold_batch_no_cut(self, whole_examples):
states = self.moves.init_batch([eg.doc for eg in whole_examples])
good_docs = []
good_golds = []
good_states = []
for i, eg in enumerate(whole_examples):
doc = eg.doc
gold = self.moves.preprocess_gold(eg.gold)
if gold is not None and self.moves.has_gold(gold):
good_docs.append(doc)
good_golds.append(gold)
good_states.append(states[i])
n_moves = []
for doc, gold in zip(good_docs, good_golds):
oracle_actions = self.moves.get_oracle_sequence(doc, gold)
n_moves.append(len(oracle_actions))
return good_states, good_golds, max(n_moves, default=0) * 2
def _init_gold_batch(self, whole_examples, min_length=5, max_length=500):
"""Make a square batch, of length equal to the shortest doc. A long
doc will get multiple states. Let's say we have a doc of length 2*N,
@ -550,16 +580,19 @@ cdef class Parser:
cdef np.ndarray d_scores = numpy.zeros((len(states), self.moves.n_moves),
dtype='f', order='C')
c_d_scores = <float*>d_scores.data
unseen_classes = self.model.attrs["unseen_classes"]
for i, (state, gold) in enumerate(zip(states, golds)):
memset(is_valid, 0, self.moves.n_moves * sizeof(int))
memset(costs, 0, self.moves.n_moves * sizeof(float))
self.moves.set_costs(is_valid, costs, state, gold)
for j in range(self.moves.n_moves):
if costs[j] <= 0.0 and j in self.model.unseen_classes:
self.model.unseen_classes.remove(j)
if costs[j] <= 0.0 and j in unseen_classes:
unseen_classes.remove(j)
cpu_log_loss(c_d_scores,
costs, is_valid, &scores[i, 0], d_scores.shape[1])
c_d_scores += d_scores.shape[1]
if len(states):
d_scores /= len(states)
if losses is not None:
losses.setdefault(self.name, 0.)
losses[self.name] += (d_scores**2).sum()
@ -569,8 +602,7 @@ cdef class Parser:
return create_default_optimizer()
def set_output(self, nO):
if self.model.upper.has_dim("nO") is None:
self.model.upper.set_dim("nO", nO)
self.model.attrs["resize_output"](self.model, nO)
def begin_training(self, get_examples, pipeline=None, sgd=None, **kwargs):
self.cfg.update(kwargs)
@ -597,7 +629,6 @@ cdef class Parser:
for doc, gold in parses:
doc_sample.append(doc)
gold_sample.append(gold)
self.model.initialize(doc_sample, gold_sample)
if pipeline is not None:
self.init_multitask_objectives(get_examples, pipeline, sgd=sgd, **self.cfg)

View File

@ -65,7 +65,7 @@ def test_add_label_deserializes_correctly():
ner2 = EntityRecognizer(Vocab(), default_ner())
# the second model needs to be resized before we can call from_bytes
ner2.model.resize_output(ner1.moves.n_moves)
ner2.model.attrs["resize_output"](ner2.model, ner1.moves.n_moves)
ner2.from_bytes(ner1.to_bytes())
assert ner1.moves.n_moves == ner2.moves.n_moves
for i in range(ner1.moves.n_moves):

View File

@ -3,9 +3,9 @@ from spacy.ml.models.defaults import default_parser, default_tok2vec
from spacy.vocab import Vocab
from spacy.syntax.arc_eager import ArcEager
from spacy.syntax.nn_parser import Parser
from spacy.syntax._parser_model import ParserModel
from spacy.tokens.doc import Doc
from spacy.gold import GoldParse
from thinc.api import Model
@pytest.fixture
@ -34,7 +34,7 @@ def parser(vocab, arc_eager):
@pytest.fixture
def model(arc_eager, tok2vec, vocab):
model = default_parser()
model.resize_output(arc_eager.n_moves)
model.attrs["resize_output"](model, arc_eager.n_moves)
model.initialize()
return model
@ -50,7 +50,7 @@ def gold(doc):
def test_can_init_nn_parser(parser):
assert isinstance(parser.model, ParserModel)
assert isinstance(parser.model, Model)
def test_build_model(parser, vocab):

View File

@ -0,0 +1,417 @@
import pytest
from collections import namedtuple
from thinc.api import NumpyOps
from spacy.ml._biluo import BILUO, _get_transition_table
from spacy.pipeline.simple_ner import SimpleNER
import spacy
@pytest.fixture(params=[
["PER", "ORG", "LOC", "MISC"],
["GPE", "PERSON", "NUMBER", "CURRENCY", "EVENT"]
])
def labels(request):
return request.param
@pytest.fixture
def ops():
return NumpyOps()
def _get_actions(labels):
action_names = (
[f"B{label}" for label in labels] + \
[f"I{label}" for label in labels] + \
[f"L{label}" for label in labels] + \
[f"U{label}" for label in labels] + \
["O"]
)
A = namedtuple("actions", action_names)
return A(**{name: i for i, name in enumerate(action_names)})
def test_init_biluo_layer(labels):
model = BILUO()
model.set_dim("nO", model.attrs["get_num_actions"](len(labels)))
model.initialize()
assert model.get_dim("nO") == len(labels) * 4 + 1
def test_transition_table(ops):
labels = ["per", "loc", "org"]
table = _get_transition_table(len(labels))
a = _get_actions(labels)
assert table.shape == (2, len(a), len(a))
# Not last token, prev action was B
assert table[0, a.Bper, a.Bper] == 0
assert table[0, a.Bper, a.Bloc] == 0
assert table[0, a.Bper, a.Borg] == 0
assert table[0, a.Bper, a.Iper] == 1
assert table[0, a.Bper, a.Iloc] == 0
assert table[0, a.Bper, a.Iorg] == 0
assert table[0, a.Bper, a.Lper] == 1
assert table[0, a.Bper, a.Lloc] == 0
assert table[0, a.Bper, a.Lorg] == 0
assert table[0, a.Bper, a.Uper] == 0
assert table[0, a.Bper, a.Uloc] == 0
assert table[0, a.Bper, a.Uorg] == 0
assert table[0, a.Bper, a.O] == 0
assert table[0, a.Bloc, a.Bper] == 0
assert table[0, a.Bloc, a.Bloc] == 0
assert table[0, a.Bloc, a.Borg] == 0
assert table[0, a.Bloc, a.Iper] == 0
assert table[0, a.Bloc, a.Iloc] == 1
assert table[0, a.Bloc, a.Iorg] == 0
assert table[0, a.Bloc, a.Lper] == 0
assert table[0, a.Bloc, a.Lloc] == 1
assert table[0, a.Bloc, a.Lorg] == 0
assert table[0, a.Bloc, a.Uper] == 0
assert table[0, a.Bloc, a.Uloc] == 0
assert table[0, a.Bloc, a.Uorg] == 0
assert table[0, a.Bloc, a.O] == 0
assert table[0, a.Borg, a.Bper] == 0
assert table[0, a.Borg, a.Bloc] == 0
assert table[0, a.Borg, a.Borg] == 0
assert table[0, a.Borg, a.Iper] == 0
assert table[0, a.Borg, a.Iloc] == 0
assert table[0, a.Borg, a.Iorg] == 1
assert table[0, a.Borg, a.Lper] == 0
assert table[0, a.Borg, a.Lloc] == 0
assert table[0, a.Borg, a.Lorg] == 1
assert table[0, a.Borg, a.Uper] == 0
assert table[0, a.Borg, a.Uloc] == 0
assert table[0, a.Borg, a.Uorg] == 0
assert table[0, a.Borg, a.O] == 0
# Not last token, prev action was I
assert table[0, a.Iper, a.Bper] == 0
assert table[0, a.Iper, a.Bloc] == 0
assert table[0, a.Iper, a.Borg] == 0
assert table[0, a.Iper, a.Iper] == 1
assert table[0, a.Iper, a.Iloc] == 0
assert table[0, a.Iper, a.Iorg] == 0
assert table[0, a.Iper, a.Lper] == 1
assert table[0, a.Iper, a.Lloc] == 0
assert table[0, a.Iper, a.Lorg] == 0
assert table[0, a.Iper, a.Uper] == 0
assert table[0, a.Iper, a.Uloc] == 0
assert table[0, a.Iper, a.Uorg] == 0
assert table[0, a.Iper, a.O] == 0
assert table[0, a.Iloc, a.Bper] == 0
assert table[0, a.Iloc, a.Bloc] == 0
assert table[0, a.Iloc, a.Borg] == 0
assert table[0, a.Iloc, a.Iper] == 0
assert table[0, a.Iloc, a.Iloc] == 1
assert table[0, a.Iloc, a.Iorg] == 0
assert table[0, a.Iloc, a.Lper] == 0
assert table[0, a.Iloc, a.Lloc] == 1
assert table[0, a.Iloc, a.Lorg] == 0
assert table[0, a.Iloc, a.Uper] == 0
assert table[0, a.Iloc, a.Uloc] == 0
assert table[0, a.Iloc, a.Uorg] == 0
assert table[0, a.Iloc, a.O] == 0
assert table[0, a.Iorg, a.Bper] == 0
assert table[0, a.Iorg, a.Bloc] == 0
assert table[0, a.Iorg, a.Borg] == 0
assert table[0, a.Iorg, a.Iper] == 0
assert table[0, a.Iorg, a.Iloc] == 0
assert table[0, a.Iorg, a.Iorg] == 1
assert table[0, a.Iorg, a.Lper] == 0
assert table[0, a.Iorg, a.Lloc] == 0
assert table[0, a.Iorg, a.Lorg] == 1
assert table[0, a.Iorg, a.Uper] == 0
assert table[0, a.Iorg, a.Uloc] == 0
assert table[0, a.Iorg, a.Uorg] == 0
assert table[0, a.Iorg, a.O] == 0
# Not last token, prev action was L
assert table[0, a.Lper, a.Bper] == 1
assert table[0, a.Lper, a.Bloc] == 1
assert table[0, a.Lper, a.Borg] == 1
assert table[0, a.Lper, a.Iper] == 0
assert table[0, a.Lper, a.Iloc] == 0
assert table[0, a.Lper, a.Iorg] == 0
assert table[0, a.Lper, a.Lper] == 0
assert table[0, a.Lper, a.Lloc] == 0
assert table[0, a.Lper, a.Lorg] == 0
assert table[0, a.Lper, a.Uper] == 1
assert table[0, a.Lper, a.Uloc] == 1
assert table[0, a.Lper, a.Uorg] == 1
assert table[0, a.Lper, a.O] == 1
assert table[0, a.Lloc, a.Bper] == 1
assert table[0, a.Lloc, a.Bloc] == 1
assert table[0, a.Lloc, a.Borg] == 1
assert table[0, a.Lloc, a.Iper] == 0
assert table[0, a.Lloc, a.Iloc] == 0
assert table[0, a.Lloc, a.Iorg] == 0
assert table[0, a.Lloc, a.Lper] == 0
assert table[0, a.Lloc, a.Lloc] == 0
assert table[0, a.Lloc, a.Lorg] == 0
assert table[0, a.Lloc, a.Uper] == 1
assert table[0, a.Lloc, a.Uloc] == 1
assert table[0, a.Lloc, a.Uorg] == 1
assert table[0, a.Lloc, a.O] == 1
assert table[0, a.Lorg, a.Bper] == 1
assert table[0, a.Lorg, a.Bloc] == 1
assert table[0, a.Lorg, a.Borg] == 1
assert table[0, a.Lorg, a.Iper] == 0
assert table[0, a.Lorg, a.Iloc] == 0
assert table[0, a.Lorg, a.Iorg] == 0
assert table[0, a.Lorg, a.Lper] == 0
assert table[0, a.Lorg, a.Lloc] == 0
assert table[0, a.Lorg, a.Lorg] == 0
assert table[0, a.Lorg, a.Uper] == 1
assert table[0, a.Lorg, a.Uloc] == 1
assert table[0, a.Lorg, a.Uorg] == 1
assert table[0, a.Lorg, a.O] == 1
# Not last token, prev action was U
assert table[0, a.Uper, a.Bper] == 1
assert table[0, a.Uper, a.Bloc] == 1
assert table[0, a.Uper, a.Borg] == 1
assert table[0, a.Uper, a.Iper] == 0
assert table[0, a.Uper, a.Iloc] == 0
assert table[0, a.Uper, a.Iorg] == 0
assert table[0, a.Uper, a.Lper] == 0
assert table[0, a.Uper, a.Lloc] == 0
assert table[0, a.Uper, a.Lorg] == 0
assert table[0, a.Uper, a.Uper] == 1
assert table[0, a.Uper, a.Uloc] == 1
assert table[0, a.Uper, a.Uorg] == 1
assert table[0, a.Uper, a.O] == 1
assert table[0, a.Uloc, a.Bper] == 1
assert table[0, a.Uloc, a.Bloc] == 1
assert table[0, a.Uloc, a.Borg] == 1
assert table[0, a.Uloc, a.Iper] == 0
assert table[0, a.Uloc, a.Iloc] == 0
assert table[0, a.Uloc, a.Iorg] == 0
assert table[0, a.Uloc, a.Lper] == 0
assert table[0, a.Uloc, a.Lloc] == 0
assert table[0, a.Uloc, a.Lorg] == 0
assert table[0, a.Uloc, a.Uper] == 1
assert table[0, a.Uloc, a.Uloc] == 1
assert table[0, a.Uloc, a.Uorg] == 1
assert table[0, a.Uloc, a.O] == 1
assert table[0, a.Uorg, a.Bper] == 1
assert table[0, a.Uorg, a.Bloc] == 1
assert table[0, a.Uorg, a.Borg] == 1
assert table[0, a.Uorg, a.Iper] == 0
assert table[0, a.Uorg, a.Iloc] == 0
assert table[0, a.Uorg, a.Iorg] == 0
assert table[0, a.Uorg, a.Lper] == 0
assert table[0, a.Uorg, a.Lloc] == 0
assert table[0, a.Uorg, a.Lorg] == 0
assert table[0, a.Uorg, a.Uper] == 1
assert table[0, a.Uorg, a.Uloc] == 1
assert table[0, a.Uorg, a.Uorg] == 1
assert table[0, a.Uorg, a.O] == 1
# Not last token, prev action was O
assert table[0, a.O, a.Bper] == 1
assert table[0, a.O, a.Bloc] == 1
assert table[0, a.O, a.Borg] == 1
assert table[0, a.O, a.Iper] == 0
assert table[0, a.O, a.Iloc] == 0
assert table[0, a.O, a.Iorg] == 0
assert table[0, a.O, a.Lper] == 0
assert table[0, a.O, a.Lloc] == 0
assert table[0, a.O, a.Lorg] == 0
assert table[0, a.O, a.Uper] == 1
assert table[0, a.O, a.Uloc] == 1
assert table[0, a.O, a.Uorg] == 1
assert table[0, a.O, a.O] == 1
# Last token, prev action was B
assert table[1, a.Bper, a.Bper] == 0
assert table[1, a.Bper, a.Bloc] == 0
assert table[1, a.Bper, a.Borg] == 0
assert table[1, a.Bper, a.Iper] == 0
assert table[1, a.Bper, a.Iloc] == 0
assert table[1, a.Bper, a.Iorg] == 0
assert table[1, a.Bper, a.Lper] == 1
assert table[1, a.Bper, a.Lloc] == 0
assert table[1, a.Bper, a.Lorg] == 0
assert table[1, a.Bper, a.Uper] == 0
assert table[1, a.Bper, a.Uloc] == 0
assert table[1, a.Bper, a.Uorg] == 0
assert table[1, a.Bper, a.O] == 0
assert table[1, a.Bloc, a.Bper] == 0
assert table[1, a.Bloc, a.Bloc] == 0
assert table[0, a.Bloc, a.Borg] == 0
assert table[1, a.Bloc, a.Iper] == 0
assert table[1, a.Bloc, a.Iloc] == 0
assert table[1, a.Bloc, a.Iorg] == 0
assert table[1, a.Bloc, a.Lper] == 0
assert table[1, a.Bloc, a.Lloc] == 1
assert table[1, a.Bloc, a.Lorg] == 0
assert table[1, a.Bloc, a.Uper] == 0
assert table[1, a.Bloc, a.Uloc] == 0
assert table[1, a.Bloc, a.Uorg] == 0
assert table[1, a.Bloc, a.O] == 0
assert table[1, a.Borg, a.Bper] == 0
assert table[1, a.Borg, a.Bloc] == 0
assert table[1, a.Borg, a.Borg] == 0
assert table[1, a.Borg, a.Iper] == 0
assert table[1, a.Borg, a.Iloc] == 0
assert table[1, a.Borg, a.Iorg] == 0
assert table[1, a.Borg, a.Lper] == 0
assert table[1, a.Borg, a.Lloc] == 0
assert table[1, a.Borg, a.Lorg] == 1
assert table[1, a.Borg, a.Uper] == 0
assert table[1, a.Borg, a.Uloc] == 0
assert table[1, a.Borg, a.Uorg] == 0
assert table[1, a.Borg, a.O] == 0
# Last token, prev action was I
assert table[1, a.Iper, a.Bper] == 0
assert table[1, a.Iper, a.Bloc] == 0
assert table[1, a.Iper, a.Borg] == 0
assert table[1, a.Iper, a.Iper] == 0
assert table[1, a.Iper, a.Iloc] == 0
assert table[1, a.Iper, a.Iorg] == 0
assert table[1, a.Iper, a.Lper] == 1
assert table[1, a.Iper, a.Lloc] == 0
assert table[1, a.Iper, a.Lorg] == 0
assert table[1, a.Iper, a.Uper] == 0
assert table[1, a.Iper, a.Uloc] == 0
assert table[1, a.Iper, a.Uorg] == 0
assert table[1, a.Iper, a.O] == 0
assert table[1, a.Iloc, a.Bper] == 0
assert table[1, a.Iloc, a.Bloc] == 0
assert table[1, a.Iloc, a.Borg] == 0
assert table[1, a.Iloc, a.Iper] == 0
assert table[1, a.Iloc, a.Iloc] == 0
assert table[1, a.Iloc, a.Iorg] == 0
assert table[1, a.Iloc, a.Lper] == 0
assert table[1, a.Iloc, a.Lloc] == 1
assert table[1, a.Iloc, a.Lorg] == 0
assert table[1, a.Iloc, a.Uper] == 0
assert table[1, a.Iloc, a.Uloc] == 0
assert table[1, a.Iloc, a.Uorg] == 0
assert table[1, a.Iloc, a.O] == 0
assert table[1, a.Iorg, a.Bper] == 0
assert table[1, a.Iorg, a.Bloc] == 0
assert table[1, a.Iorg, a.Borg] == 0
assert table[1, a.Iorg, a.Iper] == 0
assert table[1, a.Iorg, a.Iloc] == 0
assert table[1, a.Iorg, a.Iorg] == 0
assert table[1, a.Iorg, a.Lper] == 0
assert table[1, a.Iorg, a.Lloc] == 0
assert table[1, a.Iorg, a.Lorg] == 1
assert table[1, a.Iorg, a.Uper] == 0
assert table[1, a.Iorg, a.Uloc] == 0
assert table[1, a.Iorg, a.Uorg] == 0
assert table[1, a.Iorg, a.O] == 0
# Last token, prev action was L
assert table[1, a.Lper, a.Bper] == 0
assert table[1, a.Lper, a.Bloc] == 0
assert table[1, a.Lper, a.Borg] == 0
assert table[1, a.Lper, a.Iper] == 0
assert table[1, a.Lper, a.Iloc] == 0
assert table[1, a.Lper, a.Iorg] == 0
assert table[1, a.Lper, a.Lper] == 0
assert table[1, a.Lper, a.Lloc] == 0
assert table[1, a.Lper, a.Lorg] == 0
assert table[1, a.Lper, a.Uper] == 1
assert table[1, a.Lper, a.Uloc] == 1
assert table[1, a.Lper, a.Uorg] == 1
assert table[1, a.Lper, a.O] == 1
assert table[1, a.Lloc, a.Bper] == 0
assert table[1, a.Lloc, a.Bloc] == 0
assert table[1, a.Lloc, a.Borg] == 0
assert table[1, a.Lloc, a.Iper] == 0
assert table[1, a.Lloc, a.Iloc] == 0
assert table[1, a.Lloc, a.Iorg] == 0
assert table[1, a.Lloc, a.Lper] == 0
assert table[1, a.Lloc, a.Lloc] == 0
assert table[1, a.Lloc, a.Lorg] == 0
assert table[1, a.Lloc, a.Uper] == 1
assert table[1, a.Lloc, a.Uloc] == 1
assert table[1, a.Lloc, a.Uorg] == 1
assert table[1, a.Lloc, a.O] == 1
assert table[1, a.Lorg, a.Bper] == 0
assert table[1, a.Lorg, a.Bloc] == 0
assert table[1, a.Lorg, a.Borg] == 0
assert table[1, a.Lorg, a.Iper] == 0
assert table[1, a.Lorg, a.Iloc] == 0
assert table[1, a.Lorg, a.Iorg] == 0
assert table[1, a.Lorg, a.Lper] == 0
assert table[1, a.Lorg, a.Lloc] == 0
assert table[1, a.Lorg, a.Lorg] == 0
assert table[1, a.Lorg, a.Uper] == 1
assert table[1, a.Lorg, a.Uloc] == 1
assert table[1, a.Lorg, a.Uorg] == 1
assert table[1, a.Lorg, a.O] == 1
# Last token, prev action was U
assert table[1, a.Uper, a.Bper] == 0
assert table[1, a.Uper, a.Bloc] == 0
assert table[1, a.Uper, a.Borg] == 0
assert table[1, a.Uper, a.Iper] == 0
assert table[1, a.Uper, a.Iloc] == 0
assert table[1, a.Uper, a.Iorg] == 0
assert table[1, a.Uper, a.Lper] == 0
assert table[1, a.Uper, a.Lloc] == 0
assert table[1, a.Uper, a.Lorg] == 0
assert table[1, a.Uper, a.Uper] == 1
assert table[1, a.Uper, a.Uloc] == 1
assert table[1, a.Uper, a.Uorg] == 1
assert table[1, a.Uper, a.O] == 1
assert table[1, a.Uloc, a.Bper] == 0
assert table[1, a.Uloc, a.Bloc] == 0
assert table[1, a.Uloc, a.Borg] == 0
assert table[1, a.Uloc, a.Iper] == 0
assert table[1, a.Uloc, a.Iloc] == 0
assert table[1, a.Uloc, a.Iorg] == 0
assert table[1, a.Uloc, a.Lper] == 0
assert table[1, a.Uloc, a.Lloc] == 0
assert table[1, a.Uloc, a.Lorg] == 0
assert table[1, a.Uloc, a.Uper] == 1
assert table[1, a.Uloc, a.Uloc] == 1
assert table[1, a.Uloc, a.Uorg] == 1
assert table[1, a.Uloc, a.O] == 1
assert table[1, a.Uorg, a.Bper] == 0
assert table[1, a.Uorg, a.Bloc] == 0
assert table[1, a.Uorg, a.Borg] == 0
assert table[1, a.Uorg, a.Iper] == 0
assert table[1, a.Uorg, a.Iloc] == 0
assert table[1, a.Uorg, a.Iorg] == 0
assert table[1, a.Uorg, a.Lper] == 0
assert table[1, a.Uorg, a.Lloc] == 0
assert table[1, a.Uorg, a.Lorg] == 0
assert table[1, a.Uorg, a.Uper] == 1
assert table[1, a.Uorg, a.Uloc] == 1
assert table[1, a.Uorg, a.Uorg] == 1
assert table[1, a.Uorg, a.O] == 1
# Last token, prev action was O
assert table[1, a.O, a.Bper] == 0
assert table[1, a.O, a.Bloc] == 0
assert table[1, a.O, a.Borg] == 0
assert table[1, a.O, a.Iper] == 0
assert table[1, a.O, a.Iloc] == 0
assert table[1, a.O, a.Iorg] == 0
assert table[1, a.O, a.Lper] == 0
assert table[1, a.O, a.Lloc] == 0
assert table[1, a.O, a.Lorg] == 0
assert table[1, a.O, a.Uper] == 1
assert table[1, a.O, a.Uloc] == 1
assert table[1, a.O, a.Uorg] == 1
assert table[1, a.O, a.O] == 1

View File

@ -34,7 +34,8 @@ def test_issue2179():
nlp2.add_pipe(nlp2.create_pipe("ner"))
assert len(nlp2.get_pipe("ner").labels) == 0
nlp2.get_pipe("ner").model.resize_output(nlp.get_pipe("ner").moves.n_moves)
model = nlp2.get_pipe("ner").model
model.attrs["resize_output"](model, nlp.get_pipe("ner").moves.n_moves)
nlp2.from_bytes(nlp.to_bytes())
assert "extra_labels" not in nlp2.get_pipe("ner").cfg
assert nlp2.get_pipe("ner").labels == ("CITIZENSHIP",)

View File

@ -104,7 +104,8 @@ def test_issue3209():
assert ner.move_names == move_names
nlp2 = English()
nlp2.add_pipe(nlp2.create_pipe("ner"))
nlp2.get_pipe("ner").model.resize_output(ner.moves.n_moves)
model = nlp2.get_pipe("ner").model
model.attrs["resize_output"](model, ner.moves.n_moves)
nlp2.from_bytes(nlp.to_bytes())
assert nlp2.get_pipe("ner").move_names == move_names

View File

@ -110,10 +110,9 @@ def test_serialize_custom_nlp():
nlp2 = spacy.load(d)
model = nlp2.get_pipe("parser").model
tok2vec = model.get_ref("tok2vec")
upper = model.upper
upper = model.get_ref("upper")
# check that we have the correct settings, not the default ones
assert tok2vec.get_dim("nO") == 321
assert upper.get_dim("nI") == 65
@ -131,8 +130,7 @@ def test_serialize_parser():
nlp2 = spacy.load(d)
model = nlp2.get_pipe("parser").model
tok2vec = model.get_ref("tok2vec")
upper = model.upper
upper = model.get_ref("upper")
# check that we have the correct settings, not the default ones
assert upper.get_dim("nI") == 66
assert tok2vec.get_dim("nO") == 333

View File

@ -63,7 +63,7 @@ def test_to_from_bytes(parser, blank_parser):
bytes_data = parser.to_bytes(exclude=["vocab"])
# the blank parser needs to be resized before we can call from_bytes
blank_parser.model.resize_output(parser.moves.n_moves)
blank_parser.model.attrs["resize_output"](blank_parser.model, parser.moves.n_moves)
blank_parser.from_bytes(bytes_data)
assert blank_parser.model is not True
assert blank_parser.moves.n_moves == parser.moves.n_moves

View File

@ -38,7 +38,7 @@ def test_util_get_package_path(package):
def test_PrecomputableAffine(nO=4, nI=5, nF=3, nP=2):
model = PrecomputableAffine(nO=nO, nI=nI, nF=nF, nP=nP)
model = PrecomputableAffine(nO=nO, nI=nI, nF=nF, nP=nP).initialize()
assert model.get_param("W").shape == (nF, nO, nP, nI)
tensor = model.ops.alloc((10, nI))
Y, get_dX = model.begin_update(tensor)

View File

@ -571,8 +571,10 @@ def decaying(start, stop, decay):
curr -= decay
def minibatch_by_words(examples, size, tuples=True, count_words=len):
"""Create minibatches of a given number of words."""
def minibatch_by_words(examples, size, tuples=True, count_words=len, tolerance=0.2):
"""Create minibatches of roughly a given number of words. If any examples
are longer than the specified batch length, they will appear in a batch by
themselves."""
if isinstance(size, int):
size_ = itertools.repeat(size)
elif isinstance(size, List):
@ -580,18 +582,36 @@ def minibatch_by_words(examples, size, tuples=True, count_words=len):
else:
size_ = size
examples = iter(examples)
oversize = []
while True:
batch_size = next(size_)
tol_size = batch_size * 0.2
batch = []
while batch_size >= 0:
if oversize:
example = oversize.pop(0)
n_words = count_words(example.doc)
batch.append(example)
batch_size -= n_words
while batch_size >= 1:
try:
example = next(examples)
except StopIteration:
if batch:
yield batch
return
batch_size -= count_words(example.doc)
batch.append(example)
if oversize:
examples = iter(oversize)
oversize = []
if batch:
yield batch
break
else:
if batch:
yield batch
return
n_words = count_words(example.doc)
if n_words < (batch_size + tol_size):
batch_size -= n_words
batch.append(example)
else:
oversize.append(example)
if batch:
yield batch
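The reworked minibatch_by_words lets a batch overshoot the target size by a tolerance (20% of the batch size) and sets aside examples that are too long for the current budget, yielding them later in batches of their own rather than looping forever. A toy check (a sketch; the namedtuple stands in for Example and only needs a .doc with a length):

from collections import namedtuple
from spacy.util import minibatch_by_words

Fake = namedtuple("Fake", ["doc"])
examples = [Fake(doc=["w"] * n) for n in (5, 5, 5, 50, 5)]
batches = list(minibatch_by_words(examples, size=10))
# the 50-token example ends up in a batch by itself instead of stalling the loop
assert any(len(batch) == 1 and len(batch[0].doc) == 50 for batch in batches)
assert sum(len(batch) for batch in batches) == len(examples)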