diff --git a/examples/training/train_textcat.py b/examples/training/train_textcat.py
index 50c852ac1..dfb95b038 100644
--- a/examples/training/train_textcat.py
+++ b/examples/training/train_textcat.py
@@ -2,70 +2,71 @@
 # coding: utf8
 """Train a convolutional neural network text classifier on the
 IMDB dataset, using the TextCategorizer component. The dataset will be loaded
-automatically via Thinc's built-in dataset loader. The model is added to
+automatically via the package `ml_datasets`. The model is added to
 spacy.pipeline, and predictions are available via `doc.cats`. For more details,
 see the documentation:
 * Training: https://spacy.io/usage/training
 
-Compatible with: spaCy v2.0.0+
+Compatible with: spaCy v3.0.0+
 """
 from __future__ import unicode_literals, print_function
 
-import ml_datasets
 import plac
 import random
 from pathlib import Path
+from ml_datasets import loaders
 
 import spacy
+from spacy import util
 from spacy.util import minibatch, compounding
+from spacy.gold import Example, GoldParse
 
 
 @plac.annotations(
-    model=("Model name. Defaults to blank 'en' model.", "option", "m", str),
+    config_path=("Path to config file", "positional", None, Path),
     output_dir=("Optional output directory", "option", "o", Path),
     n_texts=("Number of texts to train from", "option", "t", int),
     n_iter=("Number of training iterations", "option", "n", int),
     init_tok2vec=("Pretrained tok2vec weights", "option", "t2v", Path),
+    dataset=("Dataset to train on (default: imdb)", "option", "d", str),
+    threshold=("Min. number of instances for a given label (default 20)", "option", "m", int)
 )
-def main(model=None, output_dir=None, n_iter=20, n_texts=2000, init_tok2vec=None):
+def main(config_path, output_dir=None, n_iter=20, n_texts=2000, init_tok2vec=None, dataset="imdb", threshold=20):
+    if not config_path or not config_path.exists():
+        raise ValueError(f"Config file not found at {config_path}")
+
+    spacy.util.fix_random_seed()
     if output_dir is not None:
         output_dir = Path(output_dir)
         if not output_dir.exists():
             output_dir.mkdir()
 
-    if model is not None:
-        nlp = spacy.load(model)  # load existing spaCy model
-        print("Loaded model '%s'" % model)
-    else:
-        nlp = spacy.blank("en")  # create blank Language class
-        print("Created blank 'en' model")
+    print(f"Loading nlp model from {config_path}")
+    nlp_config = util.load_config(config_path, create_objects=False)["nlp"]
+    nlp = util.load_model_from_config(nlp_config)
 
-    # add the text classifier to the pipeline if it doesn't exist
-    # nlp.create_pipe works for built-ins that are registered with spaCy
+    # ensure the nlp object was defined with a textcat component
    if "textcat" not in nlp.pipe_names:
-        textcat = nlp.create_pipe(
-            "textcat", config={"exclusive_classes": True, "architecture": "simple_cnn"}
-        )
-        nlp.add_pipe(textcat, last=True)
-    # otherwise, get it, so we can add labels to it
-    else:
-        textcat = nlp.get_pipe("textcat")
+        raise ValueError(f"The nlp definition in the config does not contain a textcat component")
 
-    # add label to text classifier
-    textcat.add_label("POSITIVE")
-    textcat.add_label("NEGATIVE")
+    textcat = nlp.get_pipe("textcat")
 
-    # load the IMDB dataset
-    print("Loading IMDB data...")
-    (train_texts, train_cats), (dev_texts, dev_cats) = load_data()
-    train_texts = train_texts[:n_texts]
-    train_cats = train_cats[:n_texts]
+    # load the dataset
+    print(f"Loading dataset {dataset} ...")
+    (train_texts, train_cats), (dev_texts, dev_cats) = load_data(dataset=dataset, threshold=threshold, limit=n_texts)
     print(
         "Using {} examples ({} training, {} evaluation)".format(
             n_texts, len(train_texts), len(dev_texts)
         )
     )
-    train_data = list(zip(train_texts, [{"cats": cats} for cats in train_cats]))
+    train_examples = []
+    for text, cats in zip(train_texts, train_cats):
+        doc = nlp.make_doc(text)
+        gold = GoldParse(doc, cats=cats)
+        for cat in cats:
+            textcat.add_label(cat)
+        ex = Example.from_gold(gold, doc=doc)
+        train_examples.append(ex)
 
     # get names of other pipes to disable them during training
     pipe_exceptions = ["textcat", "trf_wordpiecer", "trf_tok2vec"]
@@ -81,8 +82,8 @@ def main(model=None, output_dir=None, n_iter=20, n_texts=2000, init_tok2vec=None
         for i in range(n_iter):
             losses = {}
             # batch up the examples using spaCy's minibatch
-            random.shuffle(train_data)
-            batches = minibatch(train_data, size=batch_sizes)
+            random.shuffle(train_examples)
+            batches = minibatch(train_examples, size=batch_sizes)
             for batch in batches:
                 nlp.update(batch, sgd=optimizer, drop=0.2, losses=losses)
             with textcat.model.use_params(optimizer.averages):
@@ -97,7 +98,7 @@ def main(model=None, output_dir=None, n_iter=20, n_texts=2000, init_tok2vec=None
             )
         )
 
-    # test the trained model
+    # test the trained model (only makes sense for sentiment analysis)
     test_text = "This movie sucked"
     doc = nlp(test_text)
     print(test_text, doc.cats)
@@ -114,14 +115,39 @@ def main(model=None, output_dir=None, n_iter=20, n_texts=2000, init_tok2vec=None
         print(test_text, doc2.cats)
 
 
-def load_data(limit=0, split=0.8):
-    """Load data from the IMDB dataset."""
+def load_data(dataset, threshold, limit=0, split=0.8):
+    """Load data from the provided dataset."""
     # Partition off part of the train data for evaluation
-    train_data, _ = ml_datasets.imdb()
+    data_loader = loaders.get(dataset)
+    train_data, _ = data_loader(limit=int(limit/split))
     random.shuffle(train_data)
-    train_data = train_data[-limit:]
     texts, labels = zip(*train_data)
-    cats = [{"POSITIVE": bool(y), "NEGATIVE": not bool(y)} for y in labels]
+
+    unique_labels = sorted(set([l for label_set in labels for l in label_set]))
+    print(f"# of unique_labels: {len(unique_labels)}")
+
+    count_values_train = dict()
+    for text, annot_list in train_data:
+        for annot in annot_list:
+            count_values_train[annot] = count_values_train.get(annot, 0) + 1
+    for value, count in sorted(count_values_train.items(), key=lambda item: item[1]):
+        if count < threshold:
+            unique_labels.remove(value)
+
+    print(f"# of unique_labels after filtering with threshold {threshold}: {len(unique_labels)}")
+
+    if unique_labels == {0, 1}:
+        cats = [{"POSITIVE": bool(y), "NEGATIVE": not bool(y)} for y in labels]
+    else:
+        cats = []
+        for y in labels:
+            if isinstance(y, str):
+                cats.append({str(label): (label == y) for label in unique_labels})
+            elif isinstance(y, set):
+                cats.append({str(label): (label in y) for label in unique_labels})
+            else:
+                raise ValueError(f"Unrecognised type of labels: {type(y)}")
+
     split = int(len(train_data) * split)
     return (texts[:split], cats[:split]), (texts[split:], cats[split:])
 
diff --git a/examples/training/train_textcat_config.cfg b/examples/training/train_textcat_config.cfg
new file mode 100644
index 000000000..7c0f36b57
--- /dev/null
+++ b/examples/training/train_textcat_config.cfg
@@ -0,0 +1,19 @@
+[nlp]
+lang = "en"
+
+[nlp.pipeline.textcat]
+factory = "textcat"
+
+[nlp.pipeline.textcat.model]
+@architectures = "spacy.TextCatCNN.v1"
+exclusive_classes = false
+
+[nlp.pipeline.textcat.model.tok2vec]
+@architectures = "spacy.HashEmbedCNN.v1"
+pretrained_vectors = null
+width = 96
+depth = 4
+embed_size = 2000
+window_size = 1
+maxout_pieces = 3
+subword_features = true
diff --git a/spacy/ml/_layers.py b/spacy/ml/_precomputable_affine.py
similarity index 100%
rename from spacy/ml/_layers.py
rename to spacy/ml/_precomputable_affine.py
diff --git a/spacy/ml/extract_ngrams.py b/spacy/ml/extract_ngrams.py
index d4195b9a4..f9f691aae 100644
--- a/spacy/ml/extract_ngrams.py
+++ b/spacy/ml/extract_ngrams.py
@@ -11,26 +11,26 @@ def extract_ngrams(ngram_size, attr=LOWER) -> Model:
     return model
 
 
-def forward(self, docs, is_train: bool):
+def forward(model, docs, is_train: bool):
     batch_keys = []
     batch_vals = []
     for doc in docs:
-        unigrams = doc.to_array([self.attrs["attr"]])
+        unigrams = model.ops.asarray(doc.to_array([model.attrs["attr"]]))
         ngrams = [unigrams]
-        for n in range(2, self.attrs["ngram_size"] + 1):
-            ngrams.append(self.ops.ngrams(n, unigrams))
-        keys = self.ops.xp.concatenate(ngrams)
-        keys, vals = self.ops.xp.unique(keys, return_counts=True)
+        for n in range(2, model.attrs["ngram_size"] + 1):
+            ngrams.append(model.ops.ngrams(n, unigrams))
+        keys = model.ops.xp.concatenate(ngrams)
+        keys, vals = model.ops.xp.unique(keys, return_counts=True)
         batch_keys.append(keys)
         batch_vals.append(vals)
     # The dtype here matches what thinc is expecting -- which differs per
     # platform (by int definition). This should be fixed once the problem
     # is fixed on Thinc's side.
-    lengths = self.ops.asarray([arr.shape[0] for arr in batch_keys], dtype=numpy.int_)
-    batch_keys = self.ops.xp.concatenate(batch_keys)
-    batch_vals = self.ops.asarray(self.ops.xp.concatenate(batch_vals), dtype="f")
+    lengths = model.ops.asarray([arr.shape[0] for arr in batch_keys], dtype=numpy.int_)
+    batch_keys = model.ops.xp.concatenate(batch_keys)
+    batch_vals = model.ops.asarray(model.ops.xp.concatenate(batch_vals), dtype="f")
 
     def backprop(dY):
-        return dY
+        return []
 
     return (batch_keys, batch_vals, lengths), backprop
diff --git a/spacy/ml/models/defaults/textcat_bow_defaults.cfg b/spacy/ml/models/defaults/textcat_bow_defaults.cfg
new file mode 100644
index 000000000..84472ea10
--- /dev/null
+++ b/spacy/ml/models/defaults/textcat_bow_defaults.cfg
@@ -0,0 +1,5 @@
+[model]
+@architectures = "spacy.TextCatBOW.v1"
+exclusive_classes = false
+ngram_size: 1
+no_output_layer: false
diff --git a/spacy/ml/models/defaults/textcat_cnn_defaults.cfg b/spacy/ml/models/defaults/textcat_cnn_defaults.cfg
new file mode 100644
index 000000000..cea1bfe54
--- /dev/null
+++ b/spacy/ml/models/defaults/textcat_cnn_defaults.cfg
@@ -0,0 +1,13 @@
+[model]
+@architectures = "spacy.TextCatCNN.v1"
+exclusive_classes = false
+
+[model.tok2vec]
+@architectures = "spacy.HashEmbedCNN.v1"
+pretrained_vectors = null
+width = 96
+depth = 4
+embed_size = 2000
+window_size = 1
+maxout_pieces = 3
+subword_features = true
diff --git a/spacy/ml/models/defaults/textcat_defaults.cfg b/spacy/ml/models/defaults/textcat_defaults.cfg
index cea1bfe54..9477b2995 100644
--- a/spacy/ml/models/defaults/textcat_defaults.cfg
+++ b/spacy/ml/models/defaults/textcat_defaults.cfg
@@ -1,13 +1,9 @@
 [model]
-@architectures = "spacy.TextCatCNN.v1"
+@architectures = "spacy.TextCat.v1"
 exclusive_classes = false
-
-[model.tok2vec]
-@architectures = "spacy.HashEmbedCNN.v1"
 pretrained_vectors = null
-width = 96
-depth = 4
+width = 64
+conv_depth = 2
 embed_size = 2000
 window_size = 1
-maxout_pieces = 3
-subword_features = true
+ngram_size = 1
diff --git a/spacy/ml/models/parser.py b/spacy/ml/models/parser.py
index d2de10a0e..f2d51c2ba 100644
--- a/spacy/ml/models/parser.py
+++ b/spacy/ml/models/parser.py
@@ -2,7 +2,7 @@ from pydantic import StrictInt
 from thinc.api import Model, chain, list2array, Linear, zero_init, use_ops
 
 from ...util import registry
-from .._layers import PrecomputableAffine
+from .._precomputable_affine import PrecomputableAffine
 from ...syntax._parser_model import ParserModel
 
 
diff --git a/spacy/ml/models/textcat.py b/spacy/ml/models/textcat.py
index 49679c8cd..ce31d058c 100644
--- a/spacy/ml/models/textcat.py
+++ b/spacy/ml/models/textcat.py
@@ -1,7 +1,11 @@
-from thinc.api import Model, chain, reduce_mean, Linear, list2ragged, Logistic
-from thinc.api import SparseLinear, Softmax
+from thinc.api import Model, reduce_mean, Linear, list2ragged, Logistic, ParametricAttention
+from thinc.api import chain, concatenate, clone, Dropout
+from thinc.api import SparseLinear, Softmax, softmax_activation, Maxout, reduce_sum, Relu, residual, expand_window
+from thinc.api import HashEmbed, with_ragged, with_array, with_cpu, uniqued, FeatureExtractor
 
-from ...attrs import ORTH
+from ..spacy_vectors import SpacyVectors
+from ... import util
+from ...attrs import ID, ORTH, NORM, PREFIX, SUFFIX, SHAPE, LOWER
 from ...util import registry
 from ..extract_ngrams import extract_ngrams
 
@@ -20,7 +24,6 @@ def build_simple_cnn_text_classifier(tok2vec, exclusive_classes, nO=None):
             model = tok2vec >> list2ragged() >> reduce_mean() >> output_layer
             model.set_ref("output_layer", output_layer)
         else:
-            # TODO: experiment with init_w=zero_init
            linear_layer = Linear(nO=nO, nI=tok2vec.get_dim("nO"))
             model = (
                 tok2vec >> list2ragged() >> reduce_mean() >> linear_layer >> Logistic()
@@ -33,13 +36,100 @@ def build_simple_cnn_text_classifier(tok2vec, exclusive_classes, nO=None):
 
 @registry.architectures.register("spacy.TextCatBOW.v1")
 def build_bow_text_classifier(exclusive_classes, ngram_size, no_output_layer, nO=None):
-    # Note: original defaults were ngram_size=1 and no_output_layer=False
     with Model.define_operators({">>": chain}):
-        model = extract_ngrams(ngram_size, attr=ORTH) >> SparseLinear(nO)
-        model.to_cpu()
+        sparse_linear = SparseLinear(nO)
+        model = extract_ngrams(ngram_size, attr=ORTH) >> sparse_linear
+        model = with_cpu(model, model.ops)
         if not no_output_layer:
-            output_layer = Softmax(nO) if exclusive_classes else Logistic(nO)
-            output_layer.to_cpu()
-            model = model >> output_layer
-            model.set_ref("output_layer", output_layer)
+            output_layer = softmax_activation() if exclusive_classes else Logistic()
+            model = model >> with_cpu(output_layer, output_layer.ops)
+    model.set_ref("output_layer", sparse_linear)
+    return model
+
+
+@registry.architectures.register("spacy.TextCat.v1")
+def build_text_classifier(width, embed_size, pretrained_vectors, exclusive_classes, ngram_size,
+                          window_size, conv_depth, nO=None):
+    cols = [ORTH, LOWER, PREFIX, SUFFIX, SHAPE, ID]
+    with Model.define_operators({">>": chain, "|": concatenate, "**": clone}):
+        lower = HashEmbed(nO=width, nV=embed_size, column=cols.index(LOWER))
+        prefix = HashEmbed(nO=width // 2, nV=embed_size, column=cols.index(PREFIX))
+        suffix = HashEmbed(nO=width // 2, nV=embed_size, column=cols.index(SUFFIX))
+        shape = HashEmbed(nO=width // 2, nV=embed_size, column=cols.index(SHAPE))
+
+        width_nI = sum(layer.get_dim("nO") for layer in [lower, prefix, suffix, shape])
+        trained_vectors = FeatureExtractor(cols) >> with_array(
+            uniqued(
+                (lower | prefix | suffix | shape)
+                >> Maxout(nO=width, nI=width_nI, normalize=True),
+                column=cols.index(ORTH),
+            )
+        )
+
+        if pretrained_vectors:
+            nlp = util.load_model(pretrained_vectors)
+            vectors = nlp.vocab.vectors
+            vector_dim = vectors.data.shape[1]
+
+            static_vectors = SpacyVectors(vectors) >> with_array(
+                Linear(width, vector_dim)
+            )
+            vector_layer = trained_vectors | static_vectors
+            vectors_width = width * 2
+        else:
+            vector_layer = trained_vectors
+            vectors_width = width
+        tok2vec = vector_layer >> with_array(
+            Maxout(width, vectors_width, normalize=True)
+            >> residual((expand_window(window_size=window_size)
+                         >> Maxout(nO=width, nI=width * ((window_size * 2) + 1), normalize=True))) ** conv_depth,
+            pad=conv_depth,
+        )
+        cnn_model = (
+            tok2vec
+            >> list2ragged()
+            >> ParametricAttention(width)
+            >> reduce_sum()
+            >> residual(Maxout(nO=width, nI=width))
+            >> Linear(nO=nO, nI=width)
+            >> Dropout(0.0)
+        )
+
+        linear_model = build_bow_text_classifier(
+            nO=nO, ngram_size=ngram_size, exclusive_classes=exclusive_classes, no_output_layer=False
+        )
+        nO_double = nO*2 if nO else None
+        if exclusive_classes:
+            output_layer = Softmax(nO=nO, nI=nO_double)
+        else:
+            output_layer = (
+                Linear(nO=nO, nI=nO_double) >> Dropout(0.0) >> Logistic()
+            )
+        model = (linear_model | cnn_model) >> output_layer
+        model.set_ref("tok2vec", tok2vec)
+    if model.has_dim("nO") is not False:
+        model.set_dim("nO", nO)
+    model.set_ref("output_layer", linear_model.get_ref("output_layer"))
+    return model
+
+
+@registry.architectures.register("spacy.TextCatLowData.v1")
+def build_text_classifier_lowdata(width, pretrained_vectors, nO=None):
+    nlp = util.load_model(pretrained_vectors)
+    vectors = nlp.vocab.vectors
+    vector_dim = vectors.data.shape[1]
+
+    # Note, before v.3, this was the default if setting "low_data" and "pretrained_dims"
+    with Model.define_operators({">>": chain, "**": clone}):
+        model = (
+            SpacyVectors(vectors)
+            >> list2ragged()
+            >> with_ragged(0, Linear(width, vector_dim))
+            >> ParametricAttention(width)
+            >> reduce_sum()
+            >> residual(Relu(width, width)) ** 2
+            >> Linear(nO, width)
+            >> Dropout(0.0)
+            >> Logistic()
+        )
     return model
diff --git a/spacy/ml/models/tok2vec.py b/spacy/ml/models/tok2vec.py
index d1a98c080..81820e56b 100644
--- a/spacy/ml/models/tok2vec.py
+++ b/spacy/ml/models/tok2vec.py
@@ -28,8 +28,6 @@ def Tok2Vec(extract, embed, encode):
     if encode.attrs.get("receptive_field", None):
         field_size = encode.attrs["receptive_field"]
     with Model.define_operators({">>": chain, "|": concatenate}):
-        if extract.has_dim("nO"):
-            _set_dims(embed, "nI", extract.get_dim("nO"))
         tok2vec = extract >> with_array(embed >> encode, pad=field_size)
     tok2vec.set_dim("nO", encode.get_dim("nO"))
     tok2vec.set_ref("embed", embed)
@@ -176,18 +174,11 @@ def MultiHashEmbed(columns, width, rows, use_subwords, pretrained_vectors, mix):
                 nr_columns = 2
             concat_columns = glove | norm
 
-    _set_dims(mix, "nI", width * nr_columns)
     embed_layer = uniqued(concat_columns >> mix, column=columns.index("ORTH"))
 
     return embed_layer
 
 
-def _set_dims(model, name, value):
-    # Loop through the model to set a specific dimension if its unset on any layer.
-    for node in model.walk():
-        if node.has_dim(name) is None:
-            node.set_dim(name, value)
-
 
 @registry.architectures.register("spacy.CharacterEmbed.v1")
 def CharacterEmbed(columns, width, rows, nM, nC, features):
     norm = HashEmbed(nO=width, nV=rows, column=columns.index("NORM"))
@@ -344,6 +335,7 @@ def build_Tok2Vec_model(
         tok2vec = tok2vec >> PyTorchLSTM(
             nO=width, nI=width, depth=bilstm_depth, bi=True
         )
-    tok2vec.set_dim("nO", width)
+    if tok2vec.has_dim("nO") is not False:
+        tok2vec.set_dim("nO", width)
     tok2vec.set_ref("embed", embed)
     return tok2vec
diff --git a/spacy/ml/spacy_vectors.py b/spacy/ml/spacy_vectors.py
new file mode 100644
index 000000000..2a4988494
--- /dev/null
+++ b/spacy/ml/spacy_vectors.py
@@ -0,0 +1,27 @@
+import numpy
+from thinc.api import Model, Unserializable
+
+
+def SpacyVectors(vectors) -> Model:
+    attrs = {"vectors": Unserializable(vectors)}
+    model = Model("spacy_vectors", forward, attrs=attrs)
+    return model
+
+
+def forward(model, docs, is_train: bool):
+    batch = []
+    vectors = model.attrs["vectors"].obj
+    for doc in docs:
+        indices = numpy.zeros((len(doc),), dtype="i")
+        for i, word in enumerate(doc):
+            if word.orth in vectors.key2row:
+                indices[i] = vectors.key2row[word.orth]
+            else:
+                indices[i] = 0
+        batch_vectors = vectors.data[indices]
+        batch.append(batch_vectors)
+
+    def backprop(dY):
+        return None
+
+    return batch, backprop
diff --git a/spacy/pipeline/pipes.pyx b/spacy/pipeline/pipes.pyx
index 9ea2507cb..296ad5089 100644
--- a/spacy/pipeline/pipes.pyx
+++ b/spacy/pipeline/pipes.pyx
@@ -148,7 +148,8 @@ class Pipe(object):
         return sgd
 
     def set_output(self, nO):
-        self.model.set_dim("nO", nO)
+        if self.model.has_dim("nO") is not False:
+            self.model.set_dim("nO", nO)
         if self.model.has_ref("output_layer"):
             self.model.get_ref("output_layer").set_dim("nO", nO)
 
@@ -1133,6 +1134,7 @@ class TextCategorizer(Pipe):
         docs = [Doc(Vocab(), words=["hello"])]
         truths, _ = self._examples_to_truth(examples)
         self.set_output(len(self.labels))
+        link_vectors_to_models(self.vocab)
         self.model.initialize(X=docs, Y=truths)
         if sgd is None:
             sgd = self.create_optimizer()
diff --git a/spacy/pipeline/tok2vec.py b/spacy/pipeline/tok2vec.py
index 4623f99b0..ef744a5da 100644
--- a/spacy/pipeline/tok2vec.py
+++ b/spacy/pipeline/tok2vec.py
@@ -131,10 +131,8 @@ class Tok2Vec(Pipe):
         get_examples (function): Function returning example training data.
         pipeline (list): The pipeline the model is part of.
         """
-        # TODO: charembed does not play nicely with dim inference yet
-        # docs = [Doc(Vocab(), words=["hello"])]
-        # self.model.initialize(X=docs)
-        self.model.initialize()
+        docs = [Doc(Vocab(), words=["hello"])]
+        self.model.initialize(X=docs)
         link_vectors_to_models(self.vocab)
 
diff --git a/spacy/tests/pipeline/test_textcat.py b/spacy/tests/pipeline/test_textcat.py
index 1b5ca9a4c..38c980428 100644
--- a/spacy/tests/pipeline/test_textcat.py
+++ b/spacy/tests/pipeline/test_textcat.py
@@ -6,10 +6,12 @@ from spacy import util
 from spacy.lang.en import English
 from spacy.language import Language
 from spacy.pipeline import TextCategorizer
-from spacy.tests.util import make_tempdir
 from spacy.tokens import Doc
 from spacy.gold import GoldParse
+from ..util import make_tempdir
+from ...ml.models.defaults import default_tok2vec
+
 
 TRAIN_DATA = [
     ("I'm so happy.", {"cats": {"POSITIVE": 1.0, "NEGATIVE": 0.0}}),
     ("I'm so angry", {"cats": {"POSITIVE": 0.0, "NEGATIVE": 1.0}}),
@@ -109,3 +111,33 @@ def test_overfitting_IO():
     cats2 = doc2.cats
     assert cats2["POSITIVE"] > 0.9
     assert cats2["POSITIVE"] + cats2["NEGATIVE"] == pytest.approx(1.0, 0.1)
+
+
+# fmt: off
+@pytest.mark.parametrize(
+    "textcat_config",
+    [
+        {"@architectures": "spacy.TextCatBOW.v1", "exclusive_classes": False, "ngram_size": 1, "no_output_layer": False},
+        {"@architectures": "spacy.TextCatBOW.v1", "exclusive_classes": True, "ngram_size": 4, "no_output_layer": False},
+        {"@architectures": "spacy.TextCatBOW.v1", "exclusive_classes": False, "ngram_size": 3, "no_output_layer": True},
+        {"@architectures": "spacy.TextCatBOW.v1", "exclusive_classes": True, "ngram_size": 2, "no_output_layer": True},
+        {"@architectures": "spacy.TextCat.v1", "exclusive_classes": False, "ngram_size": 1, "pretrained_vectors": False, "width": 64, "conv_depth": 2, "embed_size": 2000, "window_size": 2},
+        {"@architectures": "spacy.TextCat.v1", "exclusive_classes": True, "ngram_size": 5, "pretrained_vectors": False, "width": 128, "conv_depth": 2, "embed_size": 2000, "window_size": 1},
+        {"@architectures": "spacy.TextCat.v1", "exclusive_classes": True, "ngram_size": 2, "pretrained_vectors": False, "width": 32, "conv_depth": 3, "embed_size": 500, "window_size": 3},
+        {"@architectures": "spacy.TextCatCNN.v1", "tok2vec": default_tok2vec(), "exclusive_classes": True},
+        {"@architectures": "spacy.TextCatCNN.v1", "tok2vec": default_tok2vec(), "exclusive_classes": False},
+    ],
+)
+# fmt: on
+def test_textcat_configs(textcat_config):
+    pipe_config = {"model": textcat_config}
+    nlp = English()
+    textcat = nlp.create_pipe("textcat", pipe_config)
+    for _, annotations in TRAIN_DATA:
+        for label, value in annotations.get("cats").items():
+            textcat.add_label(label)
+    nlp.add_pipe(textcat)
+    optimizer = nlp.begin_training()
+    for i in range(5):
+        losses = {}
+        nlp.update(TRAIN_DATA, sgd=optimizer, losses=losses)
diff --git a/spacy/tests/test_misc.py b/spacy/tests/test_misc.py
index 6d4e75a31..1200407d7 100644
--- a/spacy/tests/test_misc.py
+++ b/spacy/tests/test_misc.py
@@ -4,8 +4,7 @@ import ctypes
 from pathlib import Path
 from spacy import util
 from spacy import prefer_gpu, require_gpu
-from spacy.ml._layers import PrecomputableAffine
-from spacy.ml._layers import _backprop_precomputable_affine_padding
+from spacy.ml._precomputable_affine import PrecomputableAffine, _backprop_precomputable_affine_padding
 
 
 @pytest.fixture
diff --git a/spacy/tests/test_tok2vec.py b/spacy/tests/test_tok2vec.py
index e1ad1f0fc..9c2e9004b 100644
--- a/spacy/tests/test_tok2vec.py
+++ b/spacy/tests/test_tok2vec.py
@@ -4,18 +4,7 @@ from spacy.ml.models.tok2vec import build_Tok2Vec_model
 from spacy.vocab import Vocab
 from spacy.tokens import Doc
 
-
-def get_batch(batch_size):
-    vocab = Vocab()
-    docs = []
-    start = 0
-    for size in range(1, batch_size + 1):
-        # Make the words numbers, so that they're distinct
-        # across the batch, and easy to track.
-        numbers = [str(i) for i in range(start, start + size)]
-        docs.append(Doc(vocab, words=numbers))
-        start += size
-    return docs
+from .util import get_batch
 
 
 # This fails in Thinc v7.3.1. Need to push patch
@@ -75,7 +64,7 @@ def test_tok2vec_batch_sizes(batch_size, width, embed_size):
 def test_tok2vec_configs(tok2vec_config):
     docs = get_batch(3)
     tok2vec = build_Tok2Vec_model(**tok2vec_config)
-    tok2vec.initialize()
+    tok2vec.initialize(docs)
     vectors, backprop = tok2vec.begin_update(docs)
     assert len(vectors) == len(docs)
     assert vectors[0].shape == (len(docs[0]), tok2vec_config["width"])
diff --git a/spacy/tests/util.py b/spacy/tests/util.py
index 958d51e11..e29342268 100644
--- a/spacy/tests/util.py
+++ b/spacy/tests/util.py
@@ -9,6 +9,8 @@ from spacy import Errors
 from spacy.tokens import Doc, Span
 from spacy.attrs import POS, TAG, HEAD, DEP, LEMMA
 
+from spacy.vocab import Vocab
+
 
 @contextlib.contextmanager
 def make_tempfile(mode="r"):
@@ -77,6 +79,19 @@ def get_doc(
     return doc
 
 
+def get_batch(batch_size):
+    vocab = Vocab()
+    docs = []
+    start = 0
+    for size in range(1, batch_size + 1):
+        # Make the words numbers, so that they're distinct
+        # across the batch, and easy to track.
+        numbers = [str(i) for i in range(start, start + size)]
+        docs.append(Doc(vocab, words=numbers))
+        start += size
+    return docs
+
+
 def apply_transition_sequence(parser, doc, sequence):
     """Perform a series of pre-specified transitions, to put the parser
     in a desired state."""
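
Note (not part of the patch): the example script above is now entirely config-driven, so the pipeline is built from the registered architectures rather than from hard-coded create_pipe() calls. A minimal sketch of the flow that main() in examples/training/train_textcat.py follows, assuming a spaCy v3 nightly where util.load_config and util.load_model_from_config behave as they are used in the hunks above:

    # Sketch only, mirroring main() in examples/training/train_textcat.py above.
    # Assumes the spaCy v3 nightly API used by this patch.
    from spacy import util

    config_path = "examples/training/train_textcat_config.cfg"
    nlp_config = util.load_config(config_path, create_objects=False)["nlp"]
    nlp = util.load_model_from_config(nlp_config)

    # main() refuses to run if the config does not define a textcat component.
    assert "textcat" in nlp.pipe_names
    textcat = nlp.get_pipe("textcat")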