mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-10 19:57:17 +03:00
Train textcat with config (#5143)
* bring back default build_text_classifier method * remove _set_dims_ hack in favor of proper dim inference * add tok2vec initialize to unit test * small fixes * add unit test for various textcat config settings * logistic output layer does not have nO * fix window_size setting * proper fix * fix W initialization * Update textcat training example * Use ml_datasets * Convert training data to `Example` format * Use `n_texts` to set proportionate dev size * fix _init renaming on latest thinc * avoid setting a non-existing dim * update to thinc==8.0.0a2 * add BOW and CNN defaults for easy testing * various experiments with train_textcat script, fix softmax activation in textcat bow * allow textcat train script to work on other datasets as well * have dataset as a parameter * train textcat from config, with example config * add config for training textcat * formatting * fix exclusive_classes * fixing BOW for GPU * bump thinc to 8.0.0a3 (not published yet so CI will fail) * add in link_vectors_to_models which got deleted Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com>
This commit is contained in:
parent
ce0e538068
commit
311133e579
|
@ -2,70 +2,71 @@
|
|||
# coding: utf8
|
||||
"""Train a convolutional neural network text classifier on the
|
||||
IMDB dataset, using the TextCategorizer component. The dataset will be loaded
|
||||
automatically via Thinc's built-in dataset loader. The model is added to
|
||||
automatically via the package `ml_datasets`. The model is added to
|
||||
spacy.pipeline, and predictions are available via `doc.cats`. For more details,
|
||||
see the documentation:
|
||||
* Training: https://spacy.io/usage/training
|
||||
|
||||
Compatible with: spaCy v2.0.0+
|
||||
Compatible with: spaCy v3.0.0+
|
||||
"""
|
||||
from __future__ import unicode_literals, print_function
|
||||
|
||||
import ml_datasets
|
||||
import plac
|
||||
import random
|
||||
from pathlib import Path
|
||||
from ml_datasets import loaders
|
||||
|
||||
import spacy
|
||||
from spacy import util
|
||||
from spacy.util import minibatch, compounding
|
||||
from spacy.gold import Example, GoldParse
|
||||
|
||||
|
||||
@plac.annotations(
|
||||
model=("Model name. Defaults to blank 'en' model.", "option", "m", str),
|
||||
config_path=("Path to config file", "positional", None, Path),
|
||||
output_dir=("Optional output directory", "option", "o", Path),
|
||||
n_texts=("Number of texts to train from", "option", "t", int),
|
||||
n_iter=("Number of training iterations", "option", "n", int),
|
||||
init_tok2vec=("Pretrained tok2vec weights", "option", "t2v", Path),
|
||||
dataset=("Dataset to train on (default: imdb)", "option", "d", str),
|
||||
threshold=("Min. number of instances for a given label (default 20)", "option", "m", int)
|
||||
)
|
||||
def main(model=None, output_dir=None, n_iter=20, n_texts=2000, init_tok2vec=None):
|
||||
def main(config_path, output_dir=None, n_iter=20, n_texts=2000, init_tok2vec=None, dataset="imdb", threshold=20):
|
||||
if not config_path or not config_path.exists():
|
||||
raise ValueError(f"Config file not found at {config_path}")
|
||||
|
||||
spacy.util.fix_random_seed()
|
||||
if output_dir is not None:
|
||||
output_dir = Path(output_dir)
|
||||
if not output_dir.exists():
|
||||
output_dir.mkdir()
|
||||
|
||||
if model is not None:
|
||||
nlp = spacy.load(model) # load existing spaCy model
|
||||
print("Loaded model '%s'" % model)
|
||||
else:
|
||||
nlp = spacy.blank("en") # create blank Language class
|
||||
print("Created blank 'en' model")
|
||||
print(f"Loading nlp model from {config_path}")
|
||||
nlp_config = util.load_config(config_path, create_objects=False)["nlp"]
|
||||
nlp = util.load_model_from_config(nlp_config)
|
||||
|
||||
# add the text classifier to the pipeline if it doesn't exist
|
||||
# nlp.create_pipe works for built-ins that are registered with spaCy
|
||||
# ensure the nlp object was defined with a textcat component
|
||||
if "textcat" not in nlp.pipe_names:
|
||||
textcat = nlp.create_pipe(
|
||||
"textcat", config={"exclusive_classes": True, "architecture": "simple_cnn"}
|
||||
)
|
||||
nlp.add_pipe(textcat, last=True)
|
||||
# otherwise, get it, so we can add labels to it
|
||||
else:
|
||||
textcat = nlp.get_pipe("textcat")
|
||||
raise ValueError(f"The nlp definition in the config does not contain a textcat component")
|
||||
|
||||
# add label to text classifier
|
||||
textcat.add_label("POSITIVE")
|
||||
textcat.add_label("NEGATIVE")
|
||||
textcat = nlp.get_pipe("textcat")
|
||||
|
||||
# load the IMDB dataset
|
||||
print("Loading IMDB data...")
|
||||
(train_texts, train_cats), (dev_texts, dev_cats) = load_data()
|
||||
train_texts = train_texts[:n_texts]
|
||||
train_cats = train_cats[:n_texts]
|
||||
# load the dataset
|
||||
print(f"Loading dataset {dataset} ...")
|
||||
(train_texts, train_cats), (dev_texts, dev_cats) = load_data(dataset=dataset, threshold=threshold, limit=n_texts)
|
||||
print(
|
||||
"Using {} examples ({} training, {} evaluation)".format(
|
||||
n_texts, len(train_texts), len(dev_texts)
|
||||
)
|
||||
)
|
||||
train_data = list(zip(train_texts, [{"cats": cats} for cats in train_cats]))
|
||||
train_examples = []
|
||||
for text, cats in zip(train_texts, train_cats):
|
||||
doc = nlp.make_doc(text)
|
||||
gold = GoldParse(doc, cats=cats)
|
||||
for cat in cats:
|
||||
textcat.add_label(cat)
|
||||
ex = Example.from_gold(gold, doc=doc)
|
||||
train_examples.append(ex)
|
||||
|
||||
# get names of other pipes to disable them during training
|
||||
pipe_exceptions = ["textcat", "trf_wordpiecer", "trf_tok2vec"]
|
||||
|
@ -81,8 +82,8 @@ def main(model=None, output_dir=None, n_iter=20, n_texts=2000, init_tok2vec=None
|
|||
for i in range(n_iter):
|
||||
losses = {}
|
||||
# batch up the examples using spaCy's minibatch
|
||||
random.shuffle(train_data)
|
||||
batches = minibatch(train_data, size=batch_sizes)
|
||||
random.shuffle(train_examples)
|
||||
batches = minibatch(train_examples, size=batch_sizes)
|
||||
for batch in batches:
|
||||
nlp.update(batch, sgd=optimizer, drop=0.2, losses=losses)
|
||||
with textcat.model.use_params(optimizer.averages):
|
||||
|
@ -97,7 +98,7 @@ def main(model=None, output_dir=None, n_iter=20, n_texts=2000, init_tok2vec=None
|
|||
)
|
||||
)
|
||||
|
||||
# test the trained model
|
||||
# test the trained model (only makes sense for sentiment analysis)
|
||||
test_text = "This movie sucked"
|
||||
doc = nlp(test_text)
|
||||
print(test_text, doc.cats)
|
||||
|
@ -114,14 +115,39 @@ def main(model=None, output_dir=None, n_iter=20, n_texts=2000, init_tok2vec=None
|
|||
print(test_text, doc2.cats)
|
||||
|
||||
|
||||
def load_data(limit=0, split=0.8):
|
||||
"""Load data from the IMDB dataset."""
|
||||
def load_data(dataset, threshold, limit=0, split=0.8):
|
||||
"""Load data from the provided dataset."""
|
||||
# Partition off part of the train data for evaluation
|
||||
train_data, _ = ml_datasets.imdb()
|
||||
data_loader = loaders.get(dataset)
|
||||
train_data, _ = data_loader(limit=int(limit/split))
|
||||
random.shuffle(train_data)
|
||||
train_data = train_data[-limit:]
|
||||
texts, labels = zip(*train_data)
|
||||
cats = [{"POSITIVE": bool(y), "NEGATIVE": not bool(y)} for y in labels]
|
||||
|
||||
unique_labels = sorted(set([l for label_set in labels for l in label_set]))
|
||||
print(f"# of unique_labels: {len(unique_labels)}")
|
||||
|
||||
count_values_train = dict()
|
||||
for text, annot_list in train_data:
|
||||
for annot in annot_list:
|
||||
count_values_train[annot] = count_values_train.get(annot, 0) + 1
|
||||
for value, count in sorted(count_values_train.items(), key=lambda item: item[1]):
|
||||
if count < threshold:
|
||||
unique_labels.remove(value)
|
||||
|
||||
print(f"# of unique_labels after filtering with threshold {threshold}: {len(unique_labels)}")
|
||||
|
||||
if unique_labels == {0, 1}:
|
||||
cats = [{"POSITIVE": bool(y), "NEGATIVE": not bool(y)} for y in labels]
|
||||
else:
|
||||
cats = []
|
||||
for y in labels:
|
||||
if isinstance(y, str):
|
||||
cats.append({str(label): (label == y) for label in unique_labels})
|
||||
elif isinstance(y, set):
|
||||
cats.append({str(label): (label in y) for label in unique_labels})
|
||||
else:
|
||||
raise ValueError(f"Unrecognised type of labels: {type(y)}")
|
||||
|
||||
split = int(len(train_data) * split)
|
||||
return (texts[:split], cats[:split]), (texts[split:], cats[split:])
|
||||
|
||||
|
|
19
examples/training/train_textcat_config.cfg
Normal file
19
examples/training/train_textcat_config.cfg
Normal file
|
@ -0,0 +1,19 @@
|
|||
[nlp]
|
||||
lang = "en"
|
||||
|
||||
[nlp.pipeline.textcat]
|
||||
factory = "textcat"
|
||||
|
||||
[nlp.pipeline.textcat.model]
|
||||
@architectures = "spacy.TextCatCNN.v1"
|
||||
exclusive_classes = false
|
||||
|
||||
[nlp.pipeline.textcat.model.tok2vec]
|
||||
@architectures = "spacy.HashEmbedCNN.v1"
|
||||
pretrained_vectors = null
|
||||
width = 96
|
||||
depth = 4
|
||||
embed_size = 2000
|
||||
window_size = 1
|
||||
maxout_pieces = 3
|
||||
subword_features = true
|
|
@ -11,26 +11,26 @@ def extract_ngrams(ngram_size, attr=LOWER) -> Model:
|
|||
return model
|
||||
|
||||
|
||||
def forward(self, docs, is_train: bool):
|
||||
def forward(model, docs, is_train: bool):
|
||||
batch_keys = []
|
||||
batch_vals = []
|
||||
for doc in docs:
|
||||
unigrams = doc.to_array([self.attrs["attr"]])
|
||||
unigrams = model.ops.asarray(doc.to_array([model.attrs["attr"]]))
|
||||
ngrams = [unigrams]
|
||||
for n in range(2, self.attrs["ngram_size"] + 1):
|
||||
ngrams.append(self.ops.ngrams(n, unigrams))
|
||||
keys = self.ops.xp.concatenate(ngrams)
|
||||
keys, vals = self.ops.xp.unique(keys, return_counts=True)
|
||||
for n in range(2, model.attrs["ngram_size"] + 1):
|
||||
ngrams.append(model.ops.ngrams(n, unigrams))
|
||||
keys = model.ops.xp.concatenate(ngrams)
|
||||
keys, vals = model.ops.xp.unique(keys, return_counts=True)
|
||||
batch_keys.append(keys)
|
||||
batch_vals.append(vals)
|
||||
# The dtype here matches what thinc is expecting -- which differs per
|
||||
# platform (by int definition). This should be fixed once the problem
|
||||
# is fixed on Thinc's side.
|
||||
lengths = self.ops.asarray([arr.shape[0] for arr in batch_keys], dtype=numpy.int_)
|
||||
batch_keys = self.ops.xp.concatenate(batch_keys)
|
||||
batch_vals = self.ops.asarray(self.ops.xp.concatenate(batch_vals), dtype="f")
|
||||
lengths = model.ops.asarray([arr.shape[0] for arr in batch_keys], dtype=numpy.int_)
|
||||
batch_keys = model.ops.xp.concatenate(batch_keys)
|
||||
batch_vals = model.ops.asarray(model.ops.xp.concatenate(batch_vals), dtype="f")
|
||||
|
||||
def backprop(dY):
|
||||
return dY
|
||||
return []
|
||||
|
||||
return (batch_keys, batch_vals, lengths), backprop
|
||||
|
|
5
spacy/ml/models/defaults/textcat_bow_defaults.cfg
Normal file
5
spacy/ml/models/defaults/textcat_bow_defaults.cfg
Normal file
|
@ -0,0 +1,5 @@
|
|||
[model]
|
||||
@architectures = "spacy.TextCatBOW.v1"
|
||||
exclusive_classes = false
|
||||
ngram_size: 1
|
||||
no_output_layer: false
|
13
spacy/ml/models/defaults/textcat_cnn_defaults.cfg
Normal file
13
spacy/ml/models/defaults/textcat_cnn_defaults.cfg
Normal file
|
@ -0,0 +1,13 @@
|
|||
[model]
|
||||
@architectures = "spacy.TextCatCNN.v1"
|
||||
exclusive_classes = false
|
||||
|
||||
[model.tok2vec]
|
||||
@architectures = "spacy.HashEmbedCNN.v1"
|
||||
pretrained_vectors = null
|
||||
width = 96
|
||||
depth = 4
|
||||
embed_size = 2000
|
||||
window_size = 1
|
||||
maxout_pieces = 3
|
||||
subword_features = true
|
|
@ -1,13 +1,9 @@
|
|||
[model]
|
||||
@architectures = "spacy.TextCatCNN.v1"
|
||||
@architectures = "spacy.TextCat.v1"
|
||||
exclusive_classes = false
|
||||
|
||||
[model.tok2vec]
|
||||
@architectures = "spacy.HashEmbedCNN.v1"
|
||||
pretrained_vectors = null
|
||||
width = 96
|
||||
depth = 4
|
||||
width = 64
|
||||
conv_depth = 2
|
||||
embed_size = 2000
|
||||
window_size = 1
|
||||
maxout_pieces = 3
|
||||
subword_features = true
|
||||
ngram_size = 1
|
||||
|
|
|
@ -2,7 +2,7 @@ from pydantic import StrictInt
|
|||
from thinc.api import Model, chain, list2array, Linear, zero_init, use_ops
|
||||
|
||||
from ...util import registry
|
||||
from .._layers import PrecomputableAffine
|
||||
from .._precomputable_affine import PrecomputableAffine
|
||||
from ...syntax._parser_model import ParserModel
|
||||
|
||||
|
||||
|
|
|
@ -1,7 +1,11 @@
|
|||
from thinc.api import Model, chain, reduce_mean, Linear, list2ragged, Logistic
|
||||
from thinc.api import SparseLinear, Softmax
|
||||
from thinc.api import Model, reduce_mean, Linear, list2ragged, Logistic, ParametricAttention
|
||||
from thinc.api import chain, concatenate, clone, Dropout
|
||||
from thinc.api import SparseLinear, Softmax, softmax_activation, Maxout, reduce_sum, Relu, residual, expand_window
|
||||
from thinc.api import HashEmbed, with_ragged, with_array, with_cpu, uniqued, FeatureExtractor
|
||||
|
||||
from ...attrs import ORTH
|
||||
from ..spacy_vectors import SpacyVectors
|
||||
from ... import util
|
||||
from ...attrs import ID, ORTH, NORM, PREFIX, SUFFIX, SHAPE, LOWER
|
||||
from ...util import registry
|
||||
from ..extract_ngrams import extract_ngrams
|
||||
|
||||
|
@ -20,7 +24,6 @@ def build_simple_cnn_text_classifier(tok2vec, exclusive_classes, nO=None):
|
|||
model = tok2vec >> list2ragged() >> reduce_mean() >> output_layer
|
||||
model.set_ref("output_layer", output_layer)
|
||||
else:
|
||||
# TODO: experiment with init_w=zero_init
|
||||
linear_layer = Linear(nO=nO, nI=tok2vec.get_dim("nO"))
|
||||
model = (
|
||||
tok2vec >> list2ragged() >> reduce_mean() >> linear_layer >> Logistic()
|
||||
|
@ -33,13 +36,100 @@ def build_simple_cnn_text_classifier(tok2vec, exclusive_classes, nO=None):
|
|||
|
||||
@registry.architectures.register("spacy.TextCatBOW.v1")
|
||||
def build_bow_text_classifier(exclusive_classes, ngram_size, no_output_layer, nO=None):
|
||||
# Note: original defaults were ngram_size=1 and no_output_layer=False
|
||||
with Model.define_operators({">>": chain}):
|
||||
model = extract_ngrams(ngram_size, attr=ORTH) >> SparseLinear(nO)
|
||||
model.to_cpu()
|
||||
sparse_linear = SparseLinear(nO)
|
||||
model = extract_ngrams(ngram_size, attr=ORTH) >> sparse_linear
|
||||
model = with_cpu(model, model.ops)
|
||||
if not no_output_layer:
|
||||
output_layer = Softmax(nO) if exclusive_classes else Logistic(nO)
|
||||
output_layer.to_cpu()
|
||||
model = model >> output_layer
|
||||
model.set_ref("output_layer", output_layer)
|
||||
output_layer = softmax_activation() if exclusive_classes else Logistic()
|
||||
model = model >> with_cpu(output_layer, output_layer.ops)
|
||||
model.set_ref("output_layer", sparse_linear)
|
||||
return model
|
||||
|
||||
|
||||
@registry.architectures.register("spacy.TextCat.v1")
|
||||
def build_text_classifier(width, embed_size, pretrained_vectors, exclusive_classes, ngram_size,
|
||||
window_size, conv_depth, nO=None):
|
||||
cols = [ORTH, LOWER, PREFIX, SUFFIX, SHAPE, ID]
|
||||
with Model.define_operators({">>": chain, "|": concatenate, "**": clone}):
|
||||
lower = HashEmbed(nO=width, nV=embed_size, column=cols.index(LOWER))
|
||||
prefix = HashEmbed(nO=width // 2, nV=embed_size, column=cols.index(PREFIX))
|
||||
suffix = HashEmbed(nO=width // 2, nV=embed_size, column=cols.index(SUFFIX))
|
||||
shape = HashEmbed(nO=width // 2, nV=embed_size, column=cols.index(SHAPE))
|
||||
|
||||
width_nI = sum(layer.get_dim("nO") for layer in [lower, prefix, suffix, shape])
|
||||
trained_vectors = FeatureExtractor(cols) >> with_array(
|
||||
uniqued(
|
||||
(lower | prefix | suffix | shape)
|
||||
>> Maxout(nO=width, nI=width_nI, normalize=True),
|
||||
column=cols.index(ORTH),
|
||||
)
|
||||
)
|
||||
|
||||
if pretrained_vectors:
|
||||
nlp = util.load_model(pretrained_vectors)
|
||||
vectors = nlp.vocab.vectors
|
||||
vector_dim = vectors.data.shape[1]
|
||||
|
||||
static_vectors = SpacyVectors(vectors) >> with_array(
|
||||
Linear(width, vector_dim)
|
||||
)
|
||||
vector_layer = trained_vectors | static_vectors
|
||||
vectors_width = width * 2
|
||||
else:
|
||||
vector_layer = trained_vectors
|
||||
vectors_width = width
|
||||
tok2vec = vector_layer >> with_array(
|
||||
Maxout(width, vectors_width, normalize=True)
|
||||
>> residual((expand_window(window_size=window_size)
|
||||
>> Maxout(nO=width, nI=width * ((window_size * 2) + 1), normalize=True))) ** conv_depth,
|
||||
pad=conv_depth,
|
||||
)
|
||||
cnn_model = (
|
||||
tok2vec
|
||||
>> list2ragged()
|
||||
>> ParametricAttention(width)
|
||||
>> reduce_sum()
|
||||
>> residual(Maxout(nO=width, nI=width))
|
||||
>> Linear(nO=nO, nI=width)
|
||||
>> Dropout(0.0)
|
||||
)
|
||||
|
||||
linear_model = build_bow_text_classifier(
|
||||
nO=nO, ngram_size=ngram_size, exclusive_classes=exclusive_classes, no_output_layer=False
|
||||
)
|
||||
nO_double = nO*2 if nO else None
|
||||
if exclusive_classes:
|
||||
output_layer = Softmax(nO=nO, nI=nO_double)
|
||||
else:
|
||||
output_layer = (
|
||||
Linear(nO=nO, nI=nO_double) >> Dropout(0.0) >> Logistic()
|
||||
)
|
||||
model = (linear_model | cnn_model) >> output_layer
|
||||
model.set_ref("tok2vec", tok2vec)
|
||||
if model.has_dim("nO") is not False:
|
||||
model.set_dim("nO", nO)
|
||||
model.set_ref("output_layer", linear_model.get_ref("output_layer"))
|
||||
return model
|
||||
|
||||
|
||||
@registry.architectures.register("spacy.TextCatLowData.v1")
|
||||
def build_text_classifier_lowdata(width, pretrained_vectors, nO=None):
|
||||
nlp = util.load_model(pretrained_vectors)
|
||||
vectors = nlp.vocab.vectors
|
||||
vector_dim = vectors.data.shape[1]
|
||||
|
||||
# Note, before v.3, this was the default if setting "low_data" and "pretrained_dims"
|
||||
with Model.define_operators({">>": chain, "**": clone}):
|
||||
model = (
|
||||
SpacyVectors(vectors)
|
||||
>> list2ragged()
|
||||
>> with_ragged(0, Linear(width, vector_dim))
|
||||
>> ParametricAttention(width)
|
||||
>> reduce_sum()
|
||||
>> residual(Relu(width, width)) ** 2
|
||||
>> Linear(nO, width)
|
||||
>> Dropout(0.0)
|
||||
>> Logistic()
|
||||
)
|
||||
return model
|
||||
|
|
|
@ -28,8 +28,6 @@ def Tok2Vec(extract, embed, encode):
|
|||
if encode.attrs.get("receptive_field", None):
|
||||
field_size = encode.attrs["receptive_field"]
|
||||
with Model.define_operators({">>": chain, "|": concatenate}):
|
||||
if extract.has_dim("nO"):
|
||||
_set_dims(embed, "nI", extract.get_dim("nO"))
|
||||
tok2vec = extract >> with_array(embed >> encode, pad=field_size)
|
||||
tok2vec.set_dim("nO", encode.get_dim("nO"))
|
||||
tok2vec.set_ref("embed", embed)
|
||||
|
@ -176,18 +174,11 @@ def MultiHashEmbed(columns, width, rows, use_subwords, pretrained_vectors, mix):
|
|||
nr_columns = 2
|
||||
concat_columns = glove | norm
|
||||
|
||||
_set_dims(mix, "nI", width * nr_columns)
|
||||
embed_layer = uniqued(concat_columns >> mix, column=columns.index("ORTH"))
|
||||
|
||||
return embed_layer
|
||||
|
||||
|
||||
def _set_dims(model, name, value):
|
||||
# Loop through the model to set a specific dimension if its unset on any layer.
|
||||
for node in model.walk():
|
||||
if node.has_dim(name) is None:
|
||||
node.set_dim(name, value)
|
||||
|
||||
@registry.architectures.register("spacy.CharacterEmbed.v1")
|
||||
def CharacterEmbed(columns, width, rows, nM, nC, features):
|
||||
norm = HashEmbed(nO=width, nV=rows, column=columns.index("NORM"))
|
||||
|
@ -344,6 +335,7 @@ def build_Tok2Vec_model(
|
|||
tok2vec = tok2vec >> PyTorchLSTM(
|
||||
nO=width, nI=width, depth=bilstm_depth, bi=True
|
||||
)
|
||||
tok2vec.set_dim("nO", width)
|
||||
if tok2vec.has_dim("nO") is not False:
|
||||
tok2vec.set_dim("nO", width)
|
||||
tok2vec.set_ref("embed", embed)
|
||||
return tok2vec
|
||||
|
|
27
spacy/ml/spacy_vectors.py
Normal file
27
spacy/ml/spacy_vectors.py
Normal file
|
@ -0,0 +1,27 @@
|
|||
import numpy
|
||||
from thinc.api import Model, Unserializable
|
||||
|
||||
|
||||
def SpacyVectors(vectors) -> Model:
|
||||
attrs = {"vectors": Unserializable(vectors)}
|
||||
model = Model("spacy_vectors", forward, attrs=attrs)
|
||||
return model
|
||||
|
||||
|
||||
def forward(model, docs, is_train: bool):
|
||||
batch = []
|
||||
vectors = model.attrs["vectors"].obj
|
||||
for doc in docs:
|
||||
indices = numpy.zeros((len(doc),), dtype="i")
|
||||
for i, word in enumerate(doc):
|
||||
if word.orth in vectors.key2row:
|
||||
indices[i] = vectors.key2row[word.orth]
|
||||
else:
|
||||
indices[i] = 0
|
||||
batch_vectors = vectors.data[indices]
|
||||
batch.append(batch_vectors)
|
||||
|
||||
def backprop(dY):
|
||||
return None
|
||||
|
||||
return batch, backprop
|
|
@ -148,7 +148,8 @@ class Pipe(object):
|
|||
return sgd
|
||||
|
||||
def set_output(self, nO):
|
||||
self.model.set_dim("nO", nO)
|
||||
if self.model.has_dim("nO") is not False:
|
||||
self.model.set_dim("nO", nO)
|
||||
if self.model.has_ref("output_layer"):
|
||||
self.model.get_ref("output_layer").set_dim("nO", nO)
|
||||
|
||||
|
@ -1133,6 +1134,7 @@ class TextCategorizer(Pipe):
|
|||
docs = [Doc(Vocab(), words=["hello"])]
|
||||
truths, _ = self._examples_to_truth(examples)
|
||||
self.set_output(len(self.labels))
|
||||
link_vectors_to_models(self.vocab)
|
||||
self.model.initialize(X=docs, Y=truths)
|
||||
if sgd is None:
|
||||
sgd = self.create_optimizer()
|
||||
|
|
|
@ -131,10 +131,8 @@ class Tok2Vec(Pipe):
|
|||
get_examples (function): Function returning example training data.
|
||||
pipeline (list): The pipeline the model is part of.
|
||||
"""
|
||||
# TODO: charembed does not play nicely with dim inference yet
|
||||
# docs = [Doc(Vocab(), words=["hello"])]
|
||||
# self.model.initialize(X=docs)
|
||||
self.model.initialize()
|
||||
docs = [Doc(Vocab(), words=["hello"])]
|
||||
self.model.initialize(X=docs)
|
||||
link_vectors_to_models(self.vocab)
|
||||
|
||||
|
||||
|
|
|
@ -6,10 +6,12 @@ from spacy import util
|
|||
from spacy.lang.en import English
|
||||
from spacy.language import Language
|
||||
from spacy.pipeline import TextCategorizer
|
||||
from spacy.tests.util import make_tempdir
|
||||
from spacy.tokens import Doc
|
||||
from spacy.gold import GoldParse
|
||||
|
||||
from ..util import make_tempdir
|
||||
from ...ml.models.defaults import default_tok2vec
|
||||
|
||||
TRAIN_DATA = [
|
||||
("I'm so happy.", {"cats": {"POSITIVE": 1.0, "NEGATIVE": 0.0}}),
|
||||
("I'm so angry", {"cats": {"POSITIVE": 0.0, "NEGATIVE": 1.0}}),
|
||||
|
@ -109,3 +111,33 @@ def test_overfitting_IO():
|
|||
cats2 = doc2.cats
|
||||
assert cats2["POSITIVE"] > 0.9
|
||||
assert cats2["POSITIVE"] + cats2["NEGATIVE"] == pytest.approx(1.0, 0.1)
|
||||
|
||||
|
||||
# fmt: off
|
||||
@pytest.mark.parametrize(
|
||||
"textcat_config",
|
||||
[
|
||||
{"@architectures": "spacy.TextCatBOW.v1", "exclusive_classes": False, "ngram_size": 1, "no_output_layer": False},
|
||||
{"@architectures": "spacy.TextCatBOW.v1", "exclusive_classes": True, "ngram_size": 4, "no_output_layer": False},
|
||||
{"@architectures": "spacy.TextCatBOW.v1", "exclusive_classes": False, "ngram_size": 3, "no_output_layer": True},
|
||||
{"@architectures": "spacy.TextCatBOW.v1", "exclusive_classes": True, "ngram_size": 2, "no_output_layer": True},
|
||||
{"@architectures": "spacy.TextCat.v1", "exclusive_classes": False, "ngram_size": 1, "pretrained_vectors": False, "width": 64, "conv_depth": 2, "embed_size": 2000, "window_size": 2},
|
||||
{"@architectures": "spacy.TextCat.v1", "exclusive_classes": True, "ngram_size": 5, "pretrained_vectors": False, "width": 128, "conv_depth": 2, "embed_size": 2000, "window_size": 1},
|
||||
{"@architectures": "spacy.TextCat.v1", "exclusive_classes": True, "ngram_size": 2, "pretrained_vectors": False, "width": 32, "conv_depth": 3, "embed_size": 500, "window_size": 3},
|
||||
{"@architectures": "spacy.TextCatCNN.v1", "tok2vec": default_tok2vec(), "exclusive_classes": True},
|
||||
{"@architectures": "spacy.TextCatCNN.v1", "tok2vec": default_tok2vec(), "exclusive_classes": False},
|
||||
],
|
||||
)
|
||||
# fmt: on
|
||||
def test_textcat_configs(textcat_config):
|
||||
pipe_config = {"model": textcat_config}
|
||||
nlp = English()
|
||||
textcat = nlp.create_pipe("textcat", pipe_config)
|
||||
for _, annotations in TRAIN_DATA:
|
||||
for label, value in annotations.get("cats").items():
|
||||
textcat.add_label(label)
|
||||
nlp.add_pipe(textcat)
|
||||
optimizer = nlp.begin_training()
|
||||
for i in range(5):
|
||||
losses = {}
|
||||
nlp.update(TRAIN_DATA, sgd=optimizer, losses=losses)
|
||||
|
|
|
@ -4,8 +4,7 @@ import ctypes
|
|||
from pathlib import Path
|
||||
from spacy import util
|
||||
from spacy import prefer_gpu, require_gpu
|
||||
from spacy.ml._layers import PrecomputableAffine
|
||||
from spacy.ml._layers import _backprop_precomputable_affine_padding
|
||||
from spacy.ml._precomputable_affine import PrecomputableAffine, _backprop_precomputable_affine_padding
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
|
@ -4,18 +4,7 @@ from spacy.ml.models.tok2vec import build_Tok2Vec_model
|
|||
from spacy.vocab import Vocab
|
||||
from spacy.tokens import Doc
|
||||
|
||||
|
||||
def get_batch(batch_size):
|
||||
vocab = Vocab()
|
||||
docs = []
|
||||
start = 0
|
||||
for size in range(1, batch_size + 1):
|
||||
# Make the words numbers, so that they're distinct
|
||||
# across the batch, and easy to track.
|
||||
numbers = [str(i) for i in range(start, start + size)]
|
||||
docs.append(Doc(vocab, words=numbers))
|
||||
start += size
|
||||
return docs
|
||||
from .util import get_batch
|
||||
|
||||
|
||||
# This fails in Thinc v7.3.1. Need to push patch
|
||||
|
@ -75,7 +64,7 @@ def test_tok2vec_batch_sizes(batch_size, width, embed_size):
|
|||
def test_tok2vec_configs(tok2vec_config):
|
||||
docs = get_batch(3)
|
||||
tok2vec = build_Tok2Vec_model(**tok2vec_config)
|
||||
tok2vec.initialize()
|
||||
tok2vec.initialize(docs)
|
||||
vectors, backprop = tok2vec.begin_update(docs)
|
||||
assert len(vectors) == len(docs)
|
||||
assert vectors[0].shape == (len(docs[0]), tok2vec_config["width"])
|
||||
|
|
|
@ -9,6 +9,8 @@ from spacy import Errors
|
|||
from spacy.tokens import Doc, Span
|
||||
from spacy.attrs import POS, TAG, HEAD, DEP, LEMMA
|
||||
|
||||
from spacy.vocab import Vocab
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def make_tempfile(mode="r"):
|
||||
|
@ -77,6 +79,19 @@ def get_doc(
|
|||
return doc
|
||||
|
||||
|
||||
def get_batch(batch_size):
|
||||
vocab = Vocab()
|
||||
docs = []
|
||||
start = 0
|
||||
for size in range(1, batch_size + 1):
|
||||
# Make the words numbers, so that they're distinct
|
||||
# across the batch, and easy to track.
|
||||
numbers = [str(i) for i in range(start, start + size)]
|
||||
docs.append(Doc(vocab, words=numbers))
|
||||
start += size
|
||||
return docs
|
||||
|
||||
|
||||
def apply_transition_sequence(parser, doc, sequence):
|
||||
"""Perform a series of pre-specified transitions, to put the parser in a
|
||||
desired state."""
|
||||
|
|
Loading…
Reference in New Issue
Block a user