Train textcat with config (#5143)

* bring back default build_text_classifier method

* remove _set_dims_ hack in favor of proper dim inference

* add tok2vec initialize to unit test

* small fixes

* add unit test for various textcat config settings

* logistic output layer does not have nO

* fix window_size setting

* proper fix

* fix W initialization

* Update textcat training example

* Use ml_datasets
* Convert training data to `Example` format
* Use `n_texts` to set proportionate dev size

* fix _init renaming on latest thinc

* avoid setting a non-existing dim

* update to thinc==8.0.0a2

* add BOW and CNN defaults for easy testing

* various experiments with train_textcat script, fix softmax activation in textcat bow

* allow textcat train script to work on other datasets as well

* have dataset as a parameter

* train textcat from config, with example config

* add config for training textcat

* formatting

* fix exclusive_classes

* fixing BOW for GPU

* bump thinc to 8.0.0a3 (not published yet so CI will fail)

* add in link_vectors_to_models which got deleted

Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com>
Authored by Sofie Van Landeghem on 2020-03-29 19:40:36 +02:00, committed via GitHub.
Commit 311133e579 (parent ce0e538068).
17 changed files with 301 additions and 98 deletions

@@ -2,70 +2,71 @@
 # coding: utf8
 """Train a convolutional neural network text classifier on the
 IMDB dataset, using the TextCategorizer component. The dataset will be loaded
-automatically via Thinc's built-in dataset loader. The model is added to
+automatically via the package `ml_datasets`. The model is added to
 spacy.pipeline, and predictions are available via `doc.cats`. For more details,
 see the documentation:
 * Training: https://spacy.io/usage/training
-Compatible with: spaCy v2.0.0+
+Compatible with: spaCy v3.0.0+
 """
 from __future__ import unicode_literals, print_function
-import ml_datasets
 import plac
 import random
 from pathlib import Path
+from ml_datasets import loaders
 import spacy
+from spacy import util
 from spacy.util import minibatch, compounding
+from spacy.gold import Example, GoldParse


 @plac.annotations(
-    model=("Model name. Defaults to blank 'en' model.", "option", "m", str),
+    config_path=("Path to config file", "positional", None, Path),
     output_dir=("Optional output directory", "option", "o", Path),
     n_texts=("Number of texts to train from", "option", "t", int),
     n_iter=("Number of training iterations", "option", "n", int),
     init_tok2vec=("Pretrained tok2vec weights", "option", "t2v", Path),
+    dataset=("Dataset to train on (default: imdb)", "option", "d", str),
+    threshold=("Min. number of instances for a given label (default 20)", "option", "m", int)
 )
-def main(model=None, output_dir=None, n_iter=20, n_texts=2000, init_tok2vec=None):
+def main(config_path, output_dir=None, n_iter=20, n_texts=2000, init_tok2vec=None, dataset="imdb", threshold=20):
+    if not config_path or not config_path.exists():
+        raise ValueError(f"Config file not found at {config_path}")
+
+    spacy.util.fix_random_seed()
     if output_dir is not None:
         output_dir = Path(output_dir)
         if not output_dir.exists():
             output_dir.mkdir()

-    if model is not None:
-        nlp = spacy.load(model)  # load existing spaCy model
-        print("Loaded model '%s'" % model)
-    else:
-        nlp = spacy.blank("en")  # create blank Language class
-        print("Created blank 'en' model")
+    print(f"Loading nlp model from {config_path}")
+    nlp_config = util.load_config(config_path, create_objects=False)["nlp"]
+    nlp = util.load_model_from_config(nlp_config)

-    # add the text classifier to the pipeline if it doesn't exist
-    # nlp.create_pipe works for built-ins that are registered with spaCy
+    # ensure the nlp object was defined with a textcat component
     if "textcat" not in nlp.pipe_names:
-        textcat = nlp.create_pipe(
-            "textcat", config={"exclusive_classes": True, "architecture": "simple_cnn"}
-        )
-        nlp.add_pipe(textcat, last=True)
-    # otherwise, get it, so we can add labels to it
-    else:
-        textcat = nlp.get_pipe("textcat")
+        raise ValueError(f"The nlp definition in the config does not contain a textcat component")

-    # add label to text classifier
-    textcat.add_label("POSITIVE")
-    textcat.add_label("NEGATIVE")
+    textcat = nlp.get_pipe("textcat")

-    # load the IMDB dataset
-    print("Loading IMDB data...")
-    (train_texts, train_cats), (dev_texts, dev_cats) = load_data()
-    train_texts = train_texts[:n_texts]
-    train_cats = train_cats[:n_texts]
+    # load the dataset
+    print(f"Loading dataset {dataset} ...")
+    (train_texts, train_cats), (dev_texts, dev_cats) = load_data(dataset=dataset, threshold=threshold, limit=n_texts)
     print(
         "Using {} examples ({} training, {} evaluation)".format(
             n_texts, len(train_texts), len(dev_texts)
         )
     )
-    train_data = list(zip(train_texts, [{"cats": cats} for cats in train_cats]))
+    train_examples = []
+    for text, cats in zip(train_texts, train_cats):
+        doc = nlp.make_doc(text)
+        gold = GoldParse(doc, cats=cats)
+        for cat in cats:
+            textcat.add_label(cat)
+        ex = Example.from_gold(gold, doc=doc)
+        train_examples.append(ex)

     # get names of other pipes to disable them during training
     pipe_exceptions = ["textcat", "trf_wordpiecer", "trf_tok2vec"]

@@ -81,8 +82,8 @@ def main(model=None, output_dir=None, n_iter=20, n_texts=2000, init_tok2vec=None
     for i in range(n_iter):
         losses = {}
         # batch up the examples using spaCy's minibatch
-        random.shuffle(train_data)
-        batches = minibatch(train_data, size=batch_sizes)
+        random.shuffle(train_examples)
+        batches = minibatch(train_examples, size=batch_sizes)
         for batch in batches:
             nlp.update(batch, sgd=optimizer, drop=0.2, losses=losses)
         with textcat.model.use_params(optimizer.averages):

@@ -97,7 +98,7 @@ def main(model=None, output_dir=None, n_iter=20, n_texts=2000, init_tok2vec=None
             )
         )

-    # test the trained model
+    # test the trained model (only makes sense for sentiment analysis)
     test_text = "This movie sucked"
     doc = nlp(test_text)
     print(test_text, doc.cats)

@@ -114,14 +115,39 @@ def main(model=None, output_dir=None, n_iter=20, n_texts=2000, init_tok2vec=None
     print(test_text, doc2.cats)


-def load_data(limit=0, split=0.8):
-    """Load data from the IMDB dataset."""
+def load_data(dataset, threshold, limit=0, split=0.8):
+    """Load data from the provided dataset."""
     # Partition off part of the train data for evaluation
-    train_data, _ = ml_datasets.imdb()
+    data_loader = loaders.get(dataset)
+    train_data, _ = data_loader(limit=int(limit/split))
     random.shuffle(train_data)
-    train_data = train_data[-limit:]
     texts, labels = zip(*train_data)
-    cats = [{"POSITIVE": bool(y), "NEGATIVE": not bool(y)} for y in labels]
+
+    unique_labels = sorted(set([l for label_set in labels for l in label_set]))
+    print(f"# of unique_labels: {len(unique_labels)}")
+
+    count_values_train = dict()
+    for text, annot_list in train_data:
+        for annot in annot_list:
+            count_values_train[annot] = count_values_train.get(annot, 0) + 1
+    for value, count in sorted(count_values_train.items(), key=lambda item: item[1]):
+        if count < threshold:
+            unique_labels.remove(value)
+    print(f"# of unique_labels after filtering with threshold {threshold}: {len(unique_labels)}")
+
+    if unique_labels == {0, 1}:
+        cats = [{"POSITIVE": bool(y), "NEGATIVE": not bool(y)} for y in labels]
+    else:
+        cats = []
+        for y in labels:
+            if isinstance(y, str):
+                cats.append({str(label): (label == y) for label in unique_labels})
+            elif isinstance(y, set):
+                cats.append({str(label): (label in y) for label in unique_labels})
+            else:
+                raise ValueError(f"Unrecognised type of labels: {type(y)}")
+
     split = int(len(train_data) * split)
     return (texts[:split], cats[:split]), (texts[split:], cats[split:])
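
In short, the example script now builds the pipeline from a config file rather than creating the textcat pipe in code, and wraps the raw (text, cats) pairs as Example objects before calling nlp.update. A condensed sketch of that flow, using only calls that appear in the diff above (the config filename and CLI flags are illustrative):

# e.g. python train_textcat.py my_textcat.cfg -d imdb -t 2000   (flags per the plac annotations above)
from pathlib import Path
from spacy import util
from spacy.gold import Example, GoldParse

nlp_config = util.load_config(Path("my_textcat.cfg"), create_objects=False)["nlp"]
nlp = util.load_model_from_config(nlp_config)
textcat = nlp.get_pipe("textcat")

cats = {"POSITIVE": 0.0, "NEGATIVE": 1.0}
for label in cats:
    textcat.add_label(label)
doc = nlp.make_doc("This movie sucked")
example = Example.from_gold(GoldParse(doc, cats=cats), doc=doc)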

@@ -0,0 +1,19 @@
+[nlp]
+lang = "en"
+
+[nlp.pipeline.textcat]
+factory = "textcat"
+
+[nlp.pipeline.textcat.model]
+@architectures = "spacy.TextCatCNN.v1"
+exclusive_classes = false
+
+[nlp.pipeline.textcat.model.tok2vec]
+@architectures = "spacy.HashEmbedCNN.v1"
+pretrained_vectors = null
+width = 96
+depth = 4
+embed_size = 2000
+window_size = 1
+maxout_pieces = 3
+subword_features = true
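
Each `@architectures` value in a block like this names a function registered in spaCy's architectures registry (the same registry the model builders below register into with `@registry.architectures.register`). A rough sketch of the lookup the config system performs, assuming the standard catalogue `get()` helper:

from spacy.util import registry

# resolve the string from the config to the registered builder function;
# the config system then calls it with the remaining keys of the block as keyword arguments
build_textcat_cnn = registry.architectures.get("spacy.TextCatCNN.v1")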

@@ -11,26 +11,26 @@ def extract_ngrams(ngram_size, attr=LOWER) -> Model:
     return model


-def forward(self, docs, is_train: bool):
+def forward(model, docs, is_train: bool):
     batch_keys = []
     batch_vals = []
     for doc in docs:
-        unigrams = doc.to_array([self.attrs["attr"]])
+        unigrams = model.ops.asarray(doc.to_array([model.attrs["attr"]]))
         ngrams = [unigrams]
-        for n in range(2, self.attrs["ngram_size"] + 1):
-            ngrams.append(self.ops.ngrams(n, unigrams))
-        keys = self.ops.xp.concatenate(ngrams)
-        keys, vals = self.ops.xp.unique(keys, return_counts=True)
+        for n in range(2, model.attrs["ngram_size"] + 1):
+            ngrams.append(model.ops.ngrams(n, unigrams))
+        keys = model.ops.xp.concatenate(ngrams)
+        keys, vals = model.ops.xp.unique(keys, return_counts=True)
         batch_keys.append(keys)
         batch_vals.append(vals)
     # The dtype here matches what thinc is expecting -- which differs per
     # platform (by int definition). This should be fixed once the problem
     # is fixed on Thinc's side.
-    lengths = self.ops.asarray([arr.shape[0] for arr in batch_keys], dtype=numpy.int_)
-    batch_keys = self.ops.xp.concatenate(batch_keys)
-    batch_vals = self.ops.asarray(self.ops.xp.concatenate(batch_vals), dtype="f")
+    lengths = model.ops.asarray([arr.shape[0] for arr in batch_keys], dtype=numpy.int_)
+    batch_keys = model.ops.xp.concatenate(batch_keys)
+    batch_vals = model.ops.asarray(model.ops.xp.concatenate(batch_vals), dtype="f")

     def backprop(dY):
-        return dY
+        return []

     return (batch_keys, batch_vals, lengths), backprop
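
The switch from `self` to `model` follows Thinc 8's functional layer API: a layer is a plain `Model` whose forward function takes the model itself as its first argument, with state stored in `model.attrs` rather than on an instance. A minimal sketch of that pattern (placeholder computation, not the real n-gram extractor):

from thinc.api import Model

def forward(model, docs, is_train: bool):
    ngram_size = model.attrs["ngram_size"]  # read state back off the model
    outputs = [doc for doc in docs]         # placeholder computation

    def backprop(d_outputs):
        return []                           # no trainable weights in this layer

    return outputs, backprop

def make_ngram_layer(ngram_size: int) -> Model:
    # state lives on attrs, not on instance attributes
    return Model("ngram_layer_sketch", forward, attrs={"ngram_size": ngram_size})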

@@ -0,0 +1,5 @@
+[model]
+@architectures = "spacy.TextCatBOW.v1"
+exclusive_classes = false
+ngram_size: 1
+no_output_layer: false

@@ -0,0 +1,13 @@
+[model]
+@architectures = "spacy.TextCatCNN.v1"
+exclusive_classes = false
+
+[model.tok2vec]
+@architectures = "spacy.HashEmbedCNN.v1"
+pretrained_vectors = null
+width = 96
+depth = 4
+embed_size = 2000
+window_size = 1
+maxout_pieces = 3
+subword_features = true

@@ -1,13 +1,9 @@
 [model]
-@architectures = "spacy.TextCatCNN.v1"
+@architectures = "spacy.TextCat.v1"
 exclusive_classes = false
-
-[model.tok2vec]
-@architectures = "spacy.HashEmbedCNN.v1"
 pretrained_vectors = null
-width = 96
-depth = 4
+width = 64
+conv_depth = 2
 embed_size = 2000
 window_size = 1
-maxout_pieces = 3
-subword_features = true
+ngram_size = 1
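
These default `[model]` blocks mirror the dictionaries passed as the pipe config in the new unit test further down; wiring the BOW defaults up by hand looks roughly like this:

from spacy.lang.en import English

textcat_model_cfg = {
    "@architectures": "spacy.TextCatBOW.v1",
    "exclusive_classes": False,
    "ngram_size": 1,
    "no_output_layer": False,
}
nlp = English()
textcat = nlp.create_pipe("textcat", {"model": textcat_model_cfg})
nlp.add_pipe(textcat)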

@@ -2,7 +2,7 @@ from pydantic import StrictInt
 from thinc.api import Model, chain, list2array, Linear, zero_init, use_ops
 from ...util import registry
-from .._layers import PrecomputableAffine
+from .._precomputable_affine import PrecomputableAffine
 from ...syntax._parser_model import ParserModel

@@ -1,7 +1,11 @@
-from thinc.api import Model, chain, reduce_mean, Linear, list2ragged, Logistic
-from thinc.api import SparseLinear, Softmax
+from thinc.api import Model, reduce_mean, Linear, list2ragged, Logistic, ParametricAttention
+from thinc.api import chain, concatenate, clone, Dropout
+from thinc.api import SparseLinear, Softmax, softmax_activation, Maxout, reduce_sum, Relu, residual, expand_window
+from thinc.api import HashEmbed, with_ragged, with_array, with_cpu, uniqued, FeatureExtractor

-from ...attrs import ORTH
+from ..spacy_vectors import SpacyVectors
+from ... import util
+from ...attrs import ID, ORTH, NORM, PREFIX, SUFFIX, SHAPE, LOWER
 from ...util import registry
 from ..extract_ngrams import extract_ngrams

@@ -20,7 +24,6 @@ def build_simple_cnn_text_classifier(tok2vec, exclusive_classes, nO=None):
         model = tok2vec >> list2ragged() >> reduce_mean() >> output_layer
         model.set_ref("output_layer", output_layer)
     else:
-        # TODO: experiment with init_w=zero_init
         linear_layer = Linear(nO=nO, nI=tok2vec.get_dim("nO"))
         model = (
             tok2vec >> list2ragged() >> reduce_mean() >> linear_layer >> Logistic()

@@ -33,13 +36,100 @@ def build_simple_cnn_text_classifier(tok2vec, exclusive_classes, nO=None):
 @registry.architectures.register("spacy.TextCatBOW.v1")
 def build_bow_text_classifier(exclusive_classes, ngram_size, no_output_layer, nO=None):
-    # Note: original defaults were ngram_size=1 and no_output_layer=False
     with Model.define_operators({">>": chain}):
-        model = extract_ngrams(ngram_size, attr=ORTH) >> SparseLinear(nO)
-        model.to_cpu()
+        sparse_linear = SparseLinear(nO)
+        model = extract_ngrams(ngram_size, attr=ORTH) >> sparse_linear
+        model = with_cpu(model, model.ops)
        if not no_output_layer:
-            output_layer = Softmax(nO) if exclusive_classes else Logistic(nO)
-            output_layer.to_cpu()
-            model = model >> output_layer
-        model.set_ref("output_layer", output_layer)
+            output_layer = softmax_activation() if exclusive_classes else Logistic()
+            model = model >> with_cpu(output_layer, output_layer.ops)
+        model.set_ref("output_layer", sparse_linear)
+    return model
+
+
+@registry.architectures.register("spacy.TextCat.v1")
+def build_text_classifier(width, embed_size, pretrained_vectors, exclusive_classes, ngram_size,
+                          window_size, conv_depth, nO=None):
+    cols = [ORTH, LOWER, PREFIX, SUFFIX, SHAPE, ID]
+    with Model.define_operators({">>": chain, "|": concatenate, "**": clone}):
+        lower = HashEmbed(nO=width, nV=embed_size, column=cols.index(LOWER))
+        prefix = HashEmbed(nO=width // 2, nV=embed_size, column=cols.index(PREFIX))
+        suffix = HashEmbed(nO=width // 2, nV=embed_size, column=cols.index(SUFFIX))
+        shape = HashEmbed(nO=width // 2, nV=embed_size, column=cols.index(SHAPE))
+        width_nI = sum(layer.get_dim("nO") for layer in [lower, prefix, suffix, shape])
+        trained_vectors = FeatureExtractor(cols) >> with_array(
+            uniqued(
+                (lower | prefix | suffix | shape)
+                >> Maxout(nO=width, nI=width_nI, normalize=True),
+                column=cols.index(ORTH),
+            )
+        )
+        if pretrained_vectors:
+            nlp = util.load_model(pretrained_vectors)
+            vectors = nlp.vocab.vectors
+            vector_dim = vectors.data.shape[1]
+
+            static_vectors = SpacyVectors(vectors) >> with_array(
+                Linear(width, vector_dim)
+            )
+            vector_layer = trained_vectors | static_vectors
+            vectors_width = width * 2
+        else:
+            vector_layer = trained_vectors
+            vectors_width = width
+        tok2vec = vector_layer >> with_array(
+            Maxout(width, vectors_width, normalize=True)
+            >> residual((expand_window(window_size=window_size)
+                >> Maxout(nO=width, nI=width * ((window_size * 2) + 1), normalize=True))) ** conv_depth,
+            pad=conv_depth,
+        )
+        cnn_model = (
+            tok2vec
+            >> list2ragged()
+            >> ParametricAttention(width)
+            >> reduce_sum()
+            >> residual(Maxout(nO=width, nI=width))
+            >> Linear(nO=nO, nI=width)
+            >> Dropout(0.0)
+        )
+
+        linear_model = build_bow_text_classifier(
+            nO=nO, ngram_size=ngram_size, exclusive_classes=exclusive_classes, no_output_layer=False
+        )
+        nO_double = nO*2 if nO else None
+        if exclusive_classes:
+            output_layer = Softmax(nO=nO, nI=nO_double)
+        else:
+            output_layer = (
+                Linear(nO=nO, nI=nO_double) >> Dropout(0.0) >> Logistic()
+            )
+        model = (linear_model | cnn_model) >> output_layer
+        model.set_ref("tok2vec", tok2vec)
+    if model.has_dim("nO") is not False:
+        model.set_dim("nO", nO)
+    model.set_ref("output_layer", linear_model.get_ref("output_layer"))
+    return model
+
+
+@registry.architectures.register("spacy.TextCatLowData.v1")
+def build_text_classifier_lowdata(width, pretrained_vectors, nO=None):
+    nlp = util.load_model(pretrained_vectors)
+    vectors = nlp.vocab.vectors
+    vector_dim = vectors.data.shape[1]
+
+    # Note, before v.3, this was the default if setting "low_data" and "pretrained_dims"
+    with Model.define_operators({">>": chain, "**": clone}):
+        model = (
+            SpacyVectors(vectors)
+            >> list2ragged()
+            >> with_ragged(0, Linear(width, vector_dim))
+            >> ParametricAttention(width)
+            >> reduce_sum()
+            >> residual(Relu(width, width)) ** 2
+            >> Linear(nO, width)
+            >> Dropout(0.0)
+            >> Logistic()
+        )
     return model
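
`spacy.TextCat.v1` is an ensemble: the bag-of-words linear model and a CNN-plus-attention model are concatenated and fed into a shared output layer, and `nO` is only set where it is known so the rest can be inferred at initialization. A construction-only sketch of the reused BOW builder (the import path is an assumption, not shown explicitly in this diff):

from spacy.ml.models.textcat import build_bow_text_classifier  # assumed module path

bow = build_bow_text_classifier(
    exclusive_classes=False, ngram_size=1, no_output_layer=False, nO=2
)
# "output_layer" now refs the SparseLinear, so Pipe.set_output() can resize nO later
output_ref = bow.get_ref("output_layer")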

@@ -28,8 +28,6 @@ def Tok2Vec(extract, embed, encode):
     if encode.attrs.get("receptive_field", None):
         field_size = encode.attrs["receptive_field"]
     with Model.define_operators({">>": chain, "|": concatenate}):
-        if extract.has_dim("nO"):
-            _set_dims(embed, "nI", extract.get_dim("nO"))
         tok2vec = extract >> with_array(embed >> encode, pad=field_size)
         tok2vec.set_dim("nO", encode.get_dim("nO"))
         tok2vec.set_ref("embed", embed)

@@ -176,18 +174,11 @@ def MultiHashEmbed(columns, width, rows, use_subwords, pretrained_vectors, mix):
         nr_columns = 2
         concat_columns = glove | norm

-        _set_dims(mix, "nI", width * nr_columns)
         embed_layer = uniqued(concat_columns >> mix, column=columns.index("ORTH"))

     return embed_layer


-def _set_dims(model, name, value):
-    # Loop through the model to set a specific dimension if its unset on any layer.
-    for node in model.walk():
-        if node.has_dim(name) is None:
-            node.set_dim(name, value)
-
-
 @registry.architectures.register("spacy.CharacterEmbed.v1")
 def CharacterEmbed(columns, width, rows, nM, nC, features):
     norm = HashEmbed(nO=width, nV=rows, column=columns.index("NORM"))

@@ -344,6 +335,7 @@ def build_Tok2Vec_model(
         tok2vec = tok2vec >> PyTorchLSTM(
             nO=width, nI=width, depth=bilstm_depth, bi=True
         )
-    tok2vec.set_dim("nO", width)
+    if tok2vec.has_dim("nO") is not False:
+        tok2vec.set_dim("nO", width)
     tok2vec.set_ref("embed", embed)
     return tok2vec

spacy/ml/spacy_vectors.py (new file)

@@ -0,0 +1,27 @@
+import numpy
+from thinc.api import Model, Unserializable
+
+
+def SpacyVectors(vectors) -> Model:
+    attrs = {"vectors": Unserializable(vectors)}
+    model = Model("spacy_vectors", forward, attrs=attrs)
+    return model
+
+
+def forward(model, docs, is_train: bool):
+    batch = []
+    vectors = model.attrs["vectors"].obj
+    for doc in docs:
+        indices = numpy.zeros((len(doc),), dtype="i")
+        for i, word in enumerate(doc):
+            if word.orth in vectors.key2row:
+                indices[i] = vectors.key2row[word.orth]
+            else:
+                indices[i] = 0
+        batch_vectors = vectors.data[indices]
+        batch.append(batch_vectors)
+
+    def backprop(dY):
+        return None
+
+    return batch, backprop
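
`SpacyVectors` is a non-trainable layer that looks up each token's row in the static vectors table, falling back to row 0 for out-of-vocabulary tokens. A usage sketch, assuming a pipeline with vectors is installed (the package name is illustrative):

import spacy
from spacy.ml.spacy_vectors import SpacyVectors

nlp = spacy.load("en_core_web_md")                     # any pipeline with vectors
layer = SpacyVectors(nlp.vocab.vectors)
batch = layer.predict([nlp.make_doc("hello world")])   # list with one (2, vector_dim) array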

@@ -148,7 +148,8 @@ class Pipe(object):
         return sgd

     def set_output(self, nO):
-        self.model.set_dim("nO", nO)
+        if self.model.has_dim("nO") is not False:
+            self.model.set_dim("nO", nO)
         if self.model.has_ref("output_layer"):
             self.model.get_ref("output_layer").set_dim("nO", nO)

@@ -1133,6 +1134,7 @@ class TextCategorizer(Pipe):
         docs = [Doc(Vocab(), words=["hello"])]
         truths, _ = self._examples_to_truth(examples)
         self.set_output(len(self.labels))
+        link_vectors_to_models(self.vocab)
         self.model.initialize(X=docs, Y=truths)
         if sgd is None:
             sgd = self.create_optimizer()
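
The guard in `set_output` works because Thinc's `has_dim` is three-valued, which is also why the Logistic output layer (which has no `nO`, per the commit notes) no longer crashes here. Restated as a standalone helper:

from thinc.api import Model

def set_output_dim(model: Model, n_labels: int) -> None:
    # has_dim("nO") returns True (set), None (declared but unset) or False (no such dim)
    if model.has_dim("nO") is not False:
        model.set_dim("nO", n_labels)
    if model.has_ref("output_layer"):
        model.get_ref("output_layer").set_dim("nO", n_labels)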

@@ -131,10 +131,8 @@ class Tok2Vec(Pipe):
         get_examples (function): Function returning example training data.
         pipeline (list): The pipeline the model is part of.
         """
-        # TODO: charembed does not play nicely with dim inference yet
-        # docs = [Doc(Vocab(), words=["hello"])]
-        # self.model.initialize(X=docs)
-        self.model.initialize()
+        docs = [Doc(Vocab(), words=["hello"])]
+        self.model.initialize(X=docs)
         link_vectors_to_models(self.vocab)
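
With the charembed issue out of the way, `Tok2Vec.begin_training` now initializes from a one-word sample Doc so missing dimensions are inferred from real input instead of being patched in by the removed `_set_dims` hack. The same idea in isolation, using the default tok2vec factory imported in the test below:

from spacy.ml.models.defaults import default_tok2vec
from spacy.tokens import Doc
from spacy.vocab import Vocab

model = default_tok2vec()                              # dims largely left unset
model.initialize(X=[Doc(Vocab(), words=["hello"])])    # shapes inferred from the sample batch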

@@ -6,10 +6,12 @@ from spacy import util
 from spacy.lang.en import English
 from spacy.language import Language
 from spacy.pipeline import TextCategorizer
-from spacy.tests.util import make_tempdir
 from spacy.tokens import Doc
 from spacy.gold import GoldParse
+from ..util import make_tempdir
+from ...ml.models.defaults import default_tok2vec

 TRAIN_DATA = [
     ("I'm so happy.", {"cats": {"POSITIVE": 1.0, "NEGATIVE": 0.0}}),
     ("I'm so angry", {"cats": {"POSITIVE": 0.0, "NEGATIVE": 1.0}}),

@@ -109,3 +111,33 @@ def test_overfitting_IO():
         cats2 = doc2.cats
         assert cats2["POSITIVE"] > 0.9
         assert cats2["POSITIVE"] + cats2["NEGATIVE"] == pytest.approx(1.0, 0.1)
+
+
+# fmt: off
+@pytest.mark.parametrize(
+    "textcat_config",
+    [
+        {"@architectures": "spacy.TextCatBOW.v1", "exclusive_classes": False, "ngram_size": 1, "no_output_layer": False},
+        {"@architectures": "spacy.TextCatBOW.v1", "exclusive_classes": True, "ngram_size": 4, "no_output_layer": False},
+        {"@architectures": "spacy.TextCatBOW.v1", "exclusive_classes": False, "ngram_size": 3, "no_output_layer": True},
+        {"@architectures": "spacy.TextCatBOW.v1", "exclusive_classes": True, "ngram_size": 2, "no_output_layer": True},
+        {"@architectures": "spacy.TextCat.v1", "exclusive_classes": False, "ngram_size": 1, "pretrained_vectors": False, "width": 64, "conv_depth": 2, "embed_size": 2000, "window_size": 2},
+        {"@architectures": "spacy.TextCat.v1", "exclusive_classes": True, "ngram_size": 5, "pretrained_vectors": False, "width": 128, "conv_depth": 2, "embed_size": 2000, "window_size": 1},
+        {"@architectures": "spacy.TextCat.v1", "exclusive_classes": True, "ngram_size": 2, "pretrained_vectors": False, "width": 32, "conv_depth": 3, "embed_size": 500, "window_size": 3},
+        {"@architectures": "spacy.TextCatCNN.v1", "tok2vec": default_tok2vec(), "exclusive_classes": True},
+        {"@architectures": "spacy.TextCatCNN.v1", "tok2vec": default_tok2vec(), "exclusive_classes": False},
+    ],
+)
+# fmt: on
+def test_textcat_configs(textcat_config):
+    pipe_config = {"model": textcat_config}
+    nlp = English()
+    textcat = nlp.create_pipe("textcat", pipe_config)
+    for _, annotations in TRAIN_DATA:
+        for label, value in annotations.get("cats").items():
+            textcat.add_label(label)
+    nlp.add_pipe(textcat)
+    optimizer = nlp.begin_training()
+    for i in range(5):
+        losses = {}
+        nlp.update(TRAIN_DATA, sgd=optimizer, losses=losses)

@@ -4,8 +4,7 @@ import ctypes
 from pathlib import Path
 from spacy import util
 from spacy import prefer_gpu, require_gpu
-from spacy.ml._layers import PrecomputableAffine
-from spacy.ml._layers import _backprop_precomputable_affine_padding
+from spacy.ml._precomputable_affine import PrecomputableAffine, _backprop_precomputable_affine_padding


 @pytest.fixture

@@ -4,18 +4,7 @@ from spacy.ml.models.tok2vec import build_Tok2Vec_model
 from spacy.vocab import Vocab
 from spacy.tokens import Doc

-
-def get_batch(batch_size):
-    vocab = Vocab()
-    docs = []
-    start = 0
-    for size in range(1, batch_size + 1):
-        # Make the words numbers, so that they're distinct
-        # across the batch, and easy to track.
-        numbers = [str(i) for i in range(start, start + size)]
-        docs.append(Doc(vocab, words=numbers))
-        start += size
-    return docs
+from .util import get_batch

 # This fails in Thinc v7.3.1. Need to push patch

@@ -75,7 +64,7 @@ def test_tok2vec_batch_sizes(batch_size, width, embed_size):
 def test_tok2vec_configs(tok2vec_config):
     docs = get_batch(3)
     tok2vec = build_Tok2Vec_model(**tok2vec_config)
-    tok2vec.initialize()
+    tok2vec.initialize(docs)
     vectors, backprop = tok2vec.begin_update(docs)
     assert len(vectors) == len(docs)
     assert vectors[0].shape == (len(docs[0]), tok2vec_config["width"])

@@ -9,6 +9,8 @@ from spacy import Errors
 from spacy.tokens import Doc, Span
 from spacy.attrs import POS, TAG, HEAD, DEP, LEMMA
+from spacy.vocab import Vocab


 @contextlib.contextmanager
 def make_tempfile(mode="r"):

@@ -77,6 +79,19 @@ def get_doc(
     return doc


+def get_batch(batch_size):
+    vocab = Vocab()
+    docs = []
+    start = 0
+    for size in range(1, batch_size + 1):
+        # Make the words numbers, so that they're distinct
+        # across the batch, and easy to track.
+        numbers = [str(i) for i in range(start, start + size)]
+        docs.append(Doc(vocab, words=numbers))
+        start += size
+    return docs
+
+
 def apply_transition_sequence(parser, doc, sequence):
     """Perform a series of pre-specified transitions, to put the parser in a
     desired state."""
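
`get_batch` builds Docs of increasing length whose tokens are distinct numbers, which keeps the shape assertions in the tok2vec tests easy to follow; for instance:

from spacy.tests.util import get_batch

docs = get_batch(3)
assert [len(doc) for doc in docs] == [1, 2, 3]
assert [token.text for token in docs[-1]] == ["3", "4", "5"]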