Mirror of https://github.com/explosion/spaCy.git (synced 2025-10-31 07:57:35 +03:00)
	Train textcat with config (#5143)
* bring back default build_text_classifier method
* remove _set_dims_ hack in favor of proper dim inference
* add tok2vec initialize to unit test
* small fixes
* add unit test for various textcat config settings
* logistic output layer does not have nO
* fix window_size setting
* proper fix
* fix W initialization
* Update textcat training example
* Use ml_datasets
* Convert training data to `Example` format
* Use `n_texts` to set proportionate dev size
* fix _init renaming on latest thinc
* avoid setting a non-existing dim
* update to thinc==8.0.0a2
* add BOW and CNN defaults for easy testing
* various experiments with train_textcat script, fix softmax activation in textcat bow
* allow textcat train script to work on other datasets as well
* have dataset as a parameter
* train textcat from config, with example config
* add config for training textcat
* formatting
* fix exclusive_classes
* fixing BOW for GPU
* bump thinc to 8.0.0a3 (not published yet so CI will fail)
* add in link_vectors_to_models which got deleted

Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com>
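For orientation (not part of the commit itself): a minimal sketch of how the new config-driven setup is used, following the helper calls visible in the updated `examples/training/train_textcat.py` in the diff below.

```python
# Hedged sketch, assuming a spaCy v3 nightly checkout with this commit applied;
# the helper names mirror the updated example script shown in the diff below.
from spacy import util

config_path = "examples/training/train_textcat_config.cfg"
nlp_config = util.load_config(config_path, create_objects=False)["nlp"]
nlp = util.load_model_from_config(nlp_config)

# The example script requires the config to define a textcat component.
textcat = nlp.get_pipe("textcat")
```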
parent ce0e538068
commit 311133e579
				|  | @ -2,70 +2,71 @@ | |||
| # coding: utf8 | ||||
| """Train a convolutional neural network text classifier on the | ||||
| IMDB dataset, using the TextCategorizer component. The dataset will be loaded | ||||
| automatically via Thinc's built-in dataset loader. The model is added to | ||||
| automatically via the package `ml_datasets`. The model is added to | ||||
| spacy.pipeline, and predictions are available via `doc.cats`. For more details, | ||||
| see the documentation: | ||||
| * Training: https://spacy.io/usage/training | ||||
| 
 | ||||
| Compatible with: spaCy v2.0.0+ | ||||
| Compatible with: spaCy v3.0.0+ | ||||
| """ | ||||
| from __future__ import unicode_literals, print_function | ||||
| 
 | ||||
| import ml_datasets | ||||
| import plac | ||||
| import random | ||||
| from pathlib import Path | ||||
| from ml_datasets import loaders | ||||
| 
 | ||||
| import spacy | ||||
| from spacy import util | ||||
| from spacy.util import minibatch, compounding | ||||
| from spacy.gold import Example, GoldParse | ||||
| 
 | ||||
| 
 | ||||
| @plac.annotations( | ||||
|     model=("Model name. Defaults to blank 'en' model.", "option", "m", str), | ||||
|     config_path=("Path to config file", "positional", None, Path), | ||||
|     output_dir=("Optional output directory", "option", "o", Path), | ||||
|     n_texts=("Number of texts to train from", "option", "t", int), | ||||
|     n_iter=("Number of training iterations", "option", "n", int), | ||||
|     init_tok2vec=("Pretrained tok2vec weights", "option", "t2v", Path), | ||||
|     dataset=("Dataset to train on (default: imdb)", "option", "d", str), | ||||
|     threshold=("Min. number of instances for a given label (default 20)", "option", "m", int) | ||||
| ) | ||||
| def main(model=None, output_dir=None, n_iter=20, n_texts=2000, init_tok2vec=None): | ||||
| def main(config_path, output_dir=None, n_iter=20, n_texts=2000, init_tok2vec=None, dataset="imdb", threshold=20): | ||||
|     if not config_path or not config_path.exists(): | ||||
|         raise ValueError(f"Config file not found at {config_path}") | ||||
| 
 | ||||
|     spacy.util.fix_random_seed() | ||||
|     if output_dir is not None: | ||||
|         output_dir = Path(output_dir) | ||||
|         if not output_dir.exists(): | ||||
|             output_dir.mkdir() | ||||
| 
 | ||||
|     if model is not None: | ||||
|         nlp = spacy.load(model)  # load existing spaCy model | ||||
|         print("Loaded model '%s'" % model) | ||||
|     else: | ||||
|         nlp = spacy.blank("en")  # create blank Language class | ||||
|         print("Created blank 'en' model") | ||||
|     print(f"Loading nlp model from {config_path}") | ||||
|     nlp_config = util.load_config(config_path, create_objects=False)["nlp"] | ||||
|     nlp = util.load_model_from_config(nlp_config) | ||||
| 
 | ||||
|     # add the text classifier to the pipeline if it doesn't exist | ||||
|     # nlp.create_pipe works for built-ins that are registered with spaCy | ||||
|     # ensure the nlp object was defined with a textcat component | ||||
|     if "textcat" not in nlp.pipe_names: | ||||
|         textcat = nlp.create_pipe( | ||||
|             "textcat", config={"exclusive_classes": True, "architecture": "simple_cnn"} | ||||
|         ) | ||||
|         nlp.add_pipe(textcat, last=True) | ||||
|     # otherwise, get it, so we can add labels to it | ||||
|     else: | ||||
|         textcat = nlp.get_pipe("textcat") | ||||
|         raise ValueError(f"The nlp definition in the config does not contain a textcat component") | ||||
| 
 | ||||
|     # add label to text classifier | ||||
|     textcat.add_label("POSITIVE") | ||||
|     textcat.add_label("NEGATIVE") | ||||
|     textcat = nlp.get_pipe("textcat") | ||||
| 
 | ||||
|     # load the IMDB dataset | ||||
|     print("Loading IMDB data...") | ||||
|     (train_texts, train_cats), (dev_texts, dev_cats) = load_data() | ||||
|     train_texts = train_texts[:n_texts] | ||||
|     train_cats = train_cats[:n_texts] | ||||
|     # load the dataset | ||||
|     print(f"Loading dataset {dataset} ...") | ||||
|     (train_texts, train_cats), (dev_texts, dev_cats) = load_data(dataset=dataset, threshold=threshold, limit=n_texts) | ||||
|     print( | ||||
|         "Using {} examples ({} training, {} evaluation)".format( | ||||
|             n_texts, len(train_texts), len(dev_texts) | ||||
|         ) | ||||
|     ) | ||||
|     train_data = list(zip(train_texts, [{"cats": cats} for cats in train_cats])) | ||||
|     train_examples = [] | ||||
|     for text, cats in zip(train_texts, train_cats): | ||||
|         doc = nlp.make_doc(text) | ||||
|         gold = GoldParse(doc, cats=cats) | ||||
|         for cat in cats: | ||||
|             textcat.add_label(cat) | ||||
|         ex = Example.from_gold(gold, doc=doc) | ||||
|         train_examples.append(ex) | ||||
| 
 | ||||
|     # get names of other pipes to disable them during training | ||||
|     pipe_exceptions = ["textcat", "trf_wordpiecer", "trf_tok2vec"] | ||||
|  | @ -81,8 +82,8 @@ def main(model=None, output_dir=None, n_iter=20, n_texts=2000, init_tok2vec=None | |||
|         for i in range(n_iter): | ||||
|             losses = {} | ||||
|             # batch up the examples using spaCy's minibatch | ||||
|             random.shuffle(train_data) | ||||
|             batches = minibatch(train_data, size=batch_sizes) | ||||
|             random.shuffle(train_examples) | ||||
|             batches = minibatch(train_examples, size=batch_sizes) | ||||
|             for batch in batches: | ||||
|                 nlp.update(batch, sgd=optimizer, drop=0.2, losses=losses) | ||||
|             with textcat.model.use_params(optimizer.averages): | ||||
|  | @ -97,7 +98,7 @@ def main(model=None, output_dir=None, n_iter=20, n_texts=2000, init_tok2vec=None | |||
|                 ) | ||||
|             ) | ||||
| 
 | ||||
|     # test the trained model | ||||
|     # test the trained model (only makes sense for sentiment analysis) | ||||
|     test_text = "This movie sucked" | ||||
|     doc = nlp(test_text) | ||||
|     print(test_text, doc.cats) | ||||
|  | @ -114,14 +115,39 @@ def main(model=None, output_dir=None, n_iter=20, n_texts=2000, init_tok2vec=None | |||
|         print(test_text, doc2.cats) | ||||
| 
 | ||||
| 
 | ||||
| def load_data(limit=0, split=0.8): | ||||
|     """Load data from the IMDB dataset.""" | ||||
| def load_data(dataset, threshold, limit=0, split=0.8): | ||||
|     """Load data from the provided dataset.""" | ||||
|     # Partition off part of the train data for evaluation | ||||
|     train_data, _ = ml_datasets.imdb() | ||||
|     data_loader = loaders.get(dataset) | ||||
|     train_data, _ = data_loader(limit=int(limit/split)) | ||||
|     random.shuffle(train_data) | ||||
|     train_data = train_data[-limit:] | ||||
|     texts, labels = zip(*train_data) | ||||
|     cats = [{"POSITIVE": bool(y), "NEGATIVE": not bool(y)} for y in labels] | ||||
| 
 | ||||
|     unique_labels = sorted(set([l for label_set in labels for l in label_set])) | ||||
|     print(f"# of unique_labels: {len(unique_labels)}") | ||||
| 
 | ||||
|     count_values_train = dict() | ||||
|     for text, annot_list in train_data: | ||||
|         for annot in annot_list: | ||||
|             count_values_train[annot] = count_values_train.get(annot, 0) + 1 | ||||
|     for value, count in sorted(count_values_train.items(), key=lambda item: item[1]): | ||||
|         if count < threshold: | ||||
|             unique_labels.remove(value) | ||||
| 
 | ||||
|     print(f"# of unique_labels after filtering with threshold {threshold}: {len(unique_labels)}") | ||||
| 
 | ||||
|     if set(unique_labels) == {0, 1}: | ||||
|         cats = [{"POSITIVE": bool(y), "NEGATIVE": not bool(y)} for y in labels] | ||||
|     else: | ||||
|         cats = [] | ||||
|         for y in labels: | ||||
|             if isinstance(y, str): | ||||
|                 cats.append({str(label): (label == y) for label in unique_labels}) | ||||
|             elif isinstance(y, set): | ||||
|                 cats.append({str(label): (label in y) for label in unique_labels}) | ||||
|             else: | ||||
|                 raise ValueError(f"Unrecognised type of labels: {type(y)}") | ||||
| 
 | ||||
|     split = int(len(train_data) * split) | ||||
|     return (texts[:split], cats[:split]), (texts[split:], cats[split:]) | ||||
| 
 | ||||
examples/training/train_textcat_config.cfg (new file, 19 lines)
							|  | @ -0,0 +1,19 @@ | |||
| [nlp] | ||||
| lang = "en" | ||||
| 
 | ||||
| [nlp.pipeline.textcat] | ||||
| factory = "textcat" | ||||
| 
 | ||||
| [nlp.pipeline.textcat.model] | ||||
| @architectures = "spacy.TextCatCNN.v1" | ||||
| exclusive_classes = false | ||||
| 
 | ||||
| [nlp.pipeline.textcat.model.tok2vec] | ||||
| @architectures = "spacy.HashEmbedCNN.v1" | ||||
| pretrained_vectors = null | ||||
| width = 96 | ||||
| depth = 4 | ||||
| embed_size = 2000 | ||||
| window_size = 1 | ||||
| maxout_pieces = 3 | ||||
| subword_features = true | ||||
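As an aside (illustrative, not part of the diff): the `@architectures` strings in this config are resolved through spaCy's architectures registry, into which the model code changes further below register their entries.

```python
# Hedged sketch: resolving a registered architecture by name, assuming the
# catalogue-backed registry API; the arguments mirror the config above.
from spacy.util import registry
from spacy.ml.models.defaults import default_tok2vec  # same helper the unit tests below use

build_textcat_cnn = registry.architectures.get("spacy.TextCatCNN.v1")
model = build_textcat_cnn(tok2vec=default_tok2vec(), exclusive_classes=False, nO=2)
```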
|  | @ -11,26 +11,26 @@ def extract_ngrams(ngram_size, attr=LOWER) -> Model: | |||
|     return model | ||||
| 
 | ||||
| 
 | ||||
| def forward(self, docs, is_train: bool): | ||||
| def forward(model, docs, is_train: bool): | ||||
|     batch_keys = [] | ||||
|     batch_vals = [] | ||||
|     for doc in docs: | ||||
|         unigrams = doc.to_array([self.attrs["attr"]]) | ||||
|         unigrams = model.ops.asarray(doc.to_array([model.attrs["attr"]])) | ||||
|         ngrams = [unigrams] | ||||
|         for n in range(2, self.attrs["ngram_size"] + 1): | ||||
|             ngrams.append(self.ops.ngrams(n, unigrams)) | ||||
|         keys = self.ops.xp.concatenate(ngrams) | ||||
|         keys, vals = self.ops.xp.unique(keys, return_counts=True) | ||||
|         for n in range(2, model.attrs["ngram_size"] + 1): | ||||
|             ngrams.append(model.ops.ngrams(n, unigrams)) | ||||
|         keys = model.ops.xp.concatenate(ngrams) | ||||
|         keys, vals = model.ops.xp.unique(keys, return_counts=True) | ||||
|         batch_keys.append(keys) | ||||
|         batch_vals.append(vals) | ||||
|     # The dtype here matches what thinc is expecting -- which differs per | ||||
|     # platform (by int definition). This should be fixed once the problem | ||||
|     # is fixed on Thinc's side. | ||||
|     lengths = self.ops.asarray([arr.shape[0] for arr in batch_keys], dtype=numpy.int_) | ||||
|     batch_keys = self.ops.xp.concatenate(batch_keys) | ||||
|     batch_vals = self.ops.asarray(self.ops.xp.concatenate(batch_vals), dtype="f") | ||||
|     lengths = model.ops.asarray([arr.shape[0] for arr in batch_keys], dtype=numpy.int_) | ||||
|     batch_keys = model.ops.xp.concatenate(batch_keys) | ||||
|     batch_vals = model.ops.asarray(model.ops.xp.concatenate(batch_vals), dtype="f") | ||||
| 
 | ||||
|     def backprop(dY): | ||||
|         return dY | ||||
|         return [] | ||||
| 
 | ||||
|     return (batch_keys, batch_vals, lengths), backprop | ||||
spacy/ml/models/defaults/textcat_bow_defaults.cfg (new file, 5 lines)
							|  | @ -0,0 +1,5 @@ | |||
| [model] | ||||
| @architectures = "spacy.TextCatBOW.v1" | ||||
| exclusive_classes = false | ||||
| ngram_size = 1 | ||||
| no_output_layer = false | ||||
							
								
								
									
spacy/ml/models/defaults/textcat_cnn_defaults.cfg (new file, 13 lines)
							|  | @ -0,0 +1,13 @@ | |||
| [model] | ||||
| @architectures = "spacy.TextCatCNN.v1" | ||||
| exclusive_classes = false | ||||
| 
 | ||||
| [model.tok2vec] | ||||
| @architectures = "spacy.HashEmbedCNN.v1" | ||||
| pretrained_vectors = null | ||||
| width = 96 | ||||
| depth = 4 | ||||
| embed_size = 2000 | ||||
| window_size = 1 | ||||
| maxout_pieces = 3 | ||||
| subword_features = true | ||||
|  | @ -1,13 +1,9 @@ | |||
| [model] | ||||
| @architectures = "spacy.TextCatCNN.v1" | ||||
| @architectures = "spacy.TextCat.v1" | ||||
| exclusive_classes = false | ||||
| 
 | ||||
| [model.tok2vec] | ||||
| @architectures = "spacy.HashEmbedCNN.v1" | ||||
| pretrained_vectors = null | ||||
| width = 96 | ||||
| depth = 4 | ||||
| width = 64 | ||||
| conv_depth = 2 | ||||
| embed_size = 2000 | ||||
| window_size = 1 | ||||
| maxout_pieces = 3 | ||||
| subword_features = true | ||||
| ngram_size = 1 | ||||
|  |  | |||
|  | @ -2,7 +2,7 @@ from pydantic import StrictInt | |||
| from thinc.api import Model, chain, list2array, Linear, zero_init, use_ops | ||||
| 
 | ||||
| from ...util import registry | ||||
| from .._layers import PrecomputableAffine | ||||
| from .._precomputable_affine import PrecomputableAffine | ||||
| from ...syntax._parser_model import ParserModel | ||||
| 
 | ||||
| 
 | ||||
|  |  | |||
|  | @ -1,7 +1,11 @@ | |||
| from thinc.api import Model, chain, reduce_mean, Linear, list2ragged, Logistic | ||||
| from thinc.api import SparseLinear, Softmax | ||||
| from thinc.api import Model, reduce_mean, Linear, list2ragged, Logistic, ParametricAttention | ||||
| from thinc.api import chain, concatenate, clone, Dropout | ||||
| from thinc.api import SparseLinear, Softmax, softmax_activation, Maxout, reduce_sum, Relu, residual, expand_window | ||||
| from thinc.api import HashEmbed, with_ragged, with_array, with_cpu, uniqued, FeatureExtractor | ||||
| 
 | ||||
| from ...attrs import ORTH | ||||
| from ..spacy_vectors import SpacyVectors | ||||
| from ... import util | ||||
| from ...attrs import ID, ORTH, NORM, PREFIX, SUFFIX, SHAPE, LOWER | ||||
| from ...util import registry | ||||
| from ..extract_ngrams import extract_ngrams | ||||
| 
 | ||||
|  | @ -20,7 +24,6 @@ def build_simple_cnn_text_classifier(tok2vec, exclusive_classes, nO=None): | |||
|             model = tok2vec >> list2ragged() >> reduce_mean() >> output_layer | ||||
|             model.set_ref("output_layer", output_layer) | ||||
|         else: | ||||
|             # TODO: experiment with init_w=zero_init | ||||
|             linear_layer = Linear(nO=nO, nI=tok2vec.get_dim("nO")) | ||||
|             model = ( | ||||
|                 tok2vec >> list2ragged() >> reduce_mean() >> linear_layer >> Logistic() | ||||
|  | @ -33,13 +36,100 @@ def build_simple_cnn_text_classifier(tok2vec, exclusive_classes, nO=None): | |||
| 
 | ||||
| @registry.architectures.register("spacy.TextCatBOW.v1") | ||||
| def build_bow_text_classifier(exclusive_classes, ngram_size, no_output_layer, nO=None): | ||||
|     # Note: original defaults were ngram_size=1 and no_output_layer=False | ||||
|     with Model.define_operators({">>": chain}): | ||||
|         model = extract_ngrams(ngram_size, attr=ORTH) >> SparseLinear(nO) | ||||
|         model.to_cpu() | ||||
|         sparse_linear = SparseLinear(nO) | ||||
|         model = extract_ngrams(ngram_size, attr=ORTH) >> sparse_linear | ||||
|         model = with_cpu(model, model.ops) | ||||
|         if not no_output_layer: | ||||
|             output_layer = Softmax(nO) if exclusive_classes else Logistic(nO) | ||||
|             output_layer.to_cpu() | ||||
|             model = model >> output_layer | ||||
|             model.set_ref("output_layer", output_layer) | ||||
|             output_layer = softmax_activation() if exclusive_classes else Logistic() | ||||
|             model = model >> with_cpu(output_layer, output_layer.ops) | ||||
|     model.set_ref("output_layer", sparse_linear) | ||||
|     return model | ||||
| 
 | ||||
| 
 | ||||
| @registry.architectures.register("spacy.TextCat.v1") | ||||
| def build_text_classifier(width, embed_size, pretrained_vectors, exclusive_classes, ngram_size, | ||||
|                           window_size, conv_depth, nO=None): | ||||
|     cols = [ORTH, LOWER, PREFIX, SUFFIX, SHAPE, ID] | ||||
|     with Model.define_operators({">>": chain, "|": concatenate, "**": clone}): | ||||
|         lower = HashEmbed(nO=width, nV=embed_size, column=cols.index(LOWER)) | ||||
|         prefix = HashEmbed(nO=width // 2, nV=embed_size, column=cols.index(PREFIX)) | ||||
|         suffix = HashEmbed(nO=width // 2, nV=embed_size, column=cols.index(SUFFIX)) | ||||
|         shape = HashEmbed(nO=width // 2, nV=embed_size, column=cols.index(SHAPE)) | ||||
| 
 | ||||
|         width_nI = sum(layer.get_dim("nO") for layer in [lower, prefix, suffix, shape]) | ||||
|         trained_vectors = FeatureExtractor(cols) >> with_array( | ||||
|             uniqued( | ||||
|                 (lower | prefix | suffix | shape) | ||||
|                 >> Maxout(nO=width, nI=width_nI, normalize=True), | ||||
|                 column=cols.index(ORTH), | ||||
|             ) | ||||
|         ) | ||||
| 
 | ||||
|         if pretrained_vectors: | ||||
|             nlp = util.load_model(pretrained_vectors) | ||||
|             vectors = nlp.vocab.vectors | ||||
|             vector_dim = vectors.data.shape[1] | ||||
| 
 | ||||
|             static_vectors = SpacyVectors(vectors) >> with_array( | ||||
|                 Linear(width, vector_dim) | ||||
|             ) | ||||
|             vector_layer = trained_vectors | static_vectors | ||||
|             vectors_width = width * 2 | ||||
|         else: | ||||
|             vector_layer = trained_vectors | ||||
|             vectors_width = width | ||||
|         tok2vec = vector_layer >> with_array( | ||||
|             Maxout(width, vectors_width, normalize=True) | ||||
|             >> residual((expand_window(window_size=window_size) | ||||
|                          >> Maxout(nO=width, nI=width * ((window_size * 2) + 1), normalize=True))) ** conv_depth, | ||||
|             pad=conv_depth, | ||||
|         ) | ||||
|         cnn_model = ( | ||||
|                 tok2vec | ||||
|                 >> list2ragged() | ||||
|                 >> ParametricAttention(width) | ||||
|                 >> reduce_sum() | ||||
|                 >> residual(Maxout(nO=width, nI=width)) | ||||
|                 >> Linear(nO=nO, nI=width) | ||||
|                 >> Dropout(0.0) | ||||
|         ) | ||||
| 
 | ||||
|         linear_model = build_bow_text_classifier( | ||||
|             nO=nO, ngram_size=ngram_size, exclusive_classes=exclusive_classes, no_output_layer=False | ||||
|         ) | ||||
|         nO_double = nO*2 if nO else None | ||||
|         if exclusive_classes: | ||||
|             output_layer = Softmax(nO=nO, nI=nO_double) | ||||
|         else: | ||||
|             output_layer = ( | ||||
|                     Linear(nO=nO, nI=nO_double) >> Dropout(0.0) >> Logistic() | ||||
|             ) | ||||
|         model = (linear_model | cnn_model) >> output_layer | ||||
|         model.set_ref("tok2vec", tok2vec) | ||||
|     if model.has_dim("nO") is not False: | ||||
|         model.set_dim("nO", nO) | ||||
|     model.set_ref("output_layer", linear_model.get_ref("output_layer")) | ||||
|     return model | ||||
| 
 | ||||
| 
 | ||||
| @registry.architectures.register("spacy.TextCatLowData.v1") | ||||
| def build_text_classifier_lowdata(width, pretrained_vectors, nO=None): | ||||
|     nlp = util.load_model(pretrained_vectors) | ||||
|     vectors = nlp.vocab.vectors | ||||
|     vector_dim = vectors.data.shape[1] | ||||
| 
 | ||||
|     # Note, before v.3, this was the default if setting "low_data" and "pretrained_dims" | ||||
|     with Model.define_operators({">>": chain, "**": clone}): | ||||
|         model = ( | ||||
|             SpacyVectors(vectors) | ||||
|             >> list2ragged() | ||||
|             >> with_ragged(0, Linear(width, vector_dim)) | ||||
|             >> ParametricAttention(width) | ||||
|             >> reduce_sum() | ||||
|             >> residual(Relu(width, width)) ** 2 | ||||
|             >> Linear(nO, width) | ||||
|             >> Dropout(0.0) | ||||
|             >> Logistic() | ||||
|         ) | ||||
|     return model | ||||
|  |  | |||
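A small usage sketch (assumptions noted in the comments) for the BOW builder registered above; the parametrized unit test further down exercises the same settings through the pipeline config.

```python
# Hedged sketch: constructing the BOW text classifier directly.
# nO=2 stands in for the number of labels; in a pipeline this is set via Pipe.set_output().
from spacy.ml.models.textcat import build_bow_text_classifier

model = build_bow_text_classifier(
    exclusive_classes=False, ngram_size=1, no_output_layer=False, nO=2
)
# "output_layer" refers to the SparseLinear layer, as set by set_ref() above.
output_layer = model.get_ref("output_layer")
```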
|  | @ -28,8 +28,6 @@ def Tok2Vec(extract, embed, encode): | |||
|     if encode.attrs.get("receptive_field", None): | ||||
|         field_size = encode.attrs["receptive_field"] | ||||
|     with Model.define_operators({">>": chain, "|": concatenate}): | ||||
|         if extract.has_dim("nO"): | ||||
|             _set_dims(embed, "nI", extract.get_dim("nO")) | ||||
|         tok2vec = extract >> with_array(embed >> encode, pad=field_size) | ||||
|     tok2vec.set_dim("nO", encode.get_dim("nO")) | ||||
|     tok2vec.set_ref("embed", embed) | ||||
|  | @ -176,18 +174,11 @@ def MultiHashEmbed(columns, width, rows, use_subwords, pretrained_vectors, mix): | |||
|                 nr_columns = 2 | ||||
|                 concat_columns = glove | norm | ||||
| 
 | ||||
|             _set_dims(mix, "nI", width * nr_columns) | ||||
|             embed_layer = uniqued(concat_columns >> mix, column=columns.index("ORTH")) | ||||
| 
 | ||||
|     return embed_layer | ||||
| 
 | ||||
| 
 | ||||
| def _set_dims(model, name, value): | ||||
|     # Loop through the model to set a specific dimension if its unset on any layer. | ||||
|     for node in model.walk(): | ||||
|         if node.has_dim(name) is None: | ||||
|             node.set_dim(name, value) | ||||
| 
 | ||||
| @registry.architectures.register("spacy.CharacterEmbed.v1") | ||||
| def CharacterEmbed(columns, width, rows, nM, nC, features): | ||||
|     norm = HashEmbed(nO=width, nV=rows, column=columns.index("NORM")) | ||||
|  | @ -344,6 +335,7 @@ def build_Tok2Vec_model( | |||
|             tok2vec = tok2vec >> PyTorchLSTM( | ||||
|                 nO=width, nI=width, depth=bilstm_depth, bi=True | ||||
|             ) | ||||
|         tok2vec.set_dim("nO", width) | ||||
|         if tok2vec.has_dim("nO") is not False: | ||||
|             tok2vec.set_dim("nO", width) | ||||
|         tok2vec.set_ref("embed", embed) | ||||
|     return tok2vec | ||||
spacy/ml/spacy_vectors.py (new file, 27 lines)
							|  | @ -0,0 +1,27 @@ | |||
| import numpy | ||||
| from thinc.api import Model, Unserializable | ||||
| 
 | ||||
| 
 | ||||
| def SpacyVectors(vectors) -> Model: | ||||
|     attrs = {"vectors": Unserializable(vectors)} | ||||
|     model = Model("spacy_vectors", forward, attrs=attrs) | ||||
|     return model | ||||
| 
 | ||||
| 
 | ||||
| def forward(model, docs, is_train: bool): | ||||
|     batch = [] | ||||
|     vectors = model.attrs["vectors"].obj | ||||
|     for doc in docs: | ||||
|         indices = numpy.zeros((len(doc),), dtype="i") | ||||
|         for i, word in enumerate(doc): | ||||
|             if word.orth in vectors.key2row: | ||||
|                 indices[i] = vectors.key2row[word.orth] | ||||
|             else: | ||||
|                 indices[i] = 0 | ||||
|         batch_vectors = vectors.data[indices] | ||||
|         batch.append(batch_vectors) | ||||
| 
 | ||||
|         def backprop(dY): | ||||
|             return None | ||||
| 
 | ||||
|     return batch, backprop | ||||
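For context (an illustrative sketch with assumed setup): `SpacyVectors` wraps a spaCy `Vectors` table as a thinc layer that returns one vector array per `Doc`; the TextCat architectures above chain it with `with_array(Linear(...))`.

```python
# Hedged sketch: calling the SpacyVectors layer directly.
# Assumes a model package with pretrained vectors is installed, e.g. en_core_web_md.
import spacy
from spacy.ml.spacy_vectors import SpacyVectors

nlp = spacy.load("en_core_web_md")
layer = SpacyVectors(nlp.vocab.vectors)
docs = [nlp.make_doc("hello world")]
batch, backprop = layer(docs, is_train=False)  # one (len(doc), vector_dim) array per doc
```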
|  | @ -148,7 +148,8 @@ class Pipe(object): | |||
|         return sgd | ||||
| 
 | ||||
|     def set_output(self, nO): | ||||
|         self.model.set_dim("nO", nO) | ||||
|         if self.model.has_dim("nO") is not False: | ||||
|             self.model.set_dim("nO", nO) | ||||
|         if self.model.has_ref("output_layer"): | ||||
|             self.model.get_ref("output_layer").set_dim("nO", nO) | ||||
| 
 | ||||
|  | @ -1133,6 +1134,7 @@ class TextCategorizer(Pipe): | |||
|         docs = [Doc(Vocab(), words=["hello"])] | ||||
|         truths, _ = self._examples_to_truth(examples) | ||||
|         self.set_output(len(self.labels)) | ||||
|         link_vectors_to_models(self.vocab) | ||||
|         self.model.initialize(X=docs, Y=truths) | ||||
|         if sgd is None: | ||||
|             sgd = self.create_optimizer() | ||||
|  |  | |||
|  | @ -131,10 +131,8 @@ class Tok2Vec(Pipe): | |||
|         get_examples (function): Function returning example training data. | ||||
|         pipeline (list): The pipeline the model is part of. | ||||
|         """ | ||||
|         # TODO: charembed does not play nicely with dim inference yet | ||||
|         # docs = [Doc(Vocab(), words=["hello"])] | ||||
|         # self.model.initialize(X=docs) | ||||
|         self.model.initialize() | ||||
|         docs = [Doc(Vocab(), words=["hello"])] | ||||
|         self.model.initialize(X=docs) | ||||
|         link_vectors_to_models(self.vocab) | ||||
| 
 | ||||
| 
 | ||||
|  |  | |||
|  | @ -6,10 +6,12 @@ from spacy import util | |||
| from spacy.lang.en import English | ||||
| from spacy.language import Language | ||||
| from spacy.pipeline import TextCategorizer | ||||
| from spacy.tests.util import make_tempdir | ||||
| from spacy.tokens import Doc | ||||
| from spacy.gold import GoldParse | ||||
| 
 | ||||
| from ..util import make_tempdir | ||||
| from ...ml.models.defaults import default_tok2vec | ||||
| 
 | ||||
| TRAIN_DATA = [ | ||||
|     ("I'm so happy.", {"cats": {"POSITIVE": 1.0, "NEGATIVE": 0.0}}), | ||||
|     ("I'm so angry", {"cats": {"POSITIVE": 0.0, "NEGATIVE": 1.0}}), | ||||
|  | @ -109,3 +111,33 @@ def test_overfitting_IO(): | |||
|         cats2 = doc2.cats | ||||
|         assert cats2["POSITIVE"] > 0.9 | ||||
|         assert cats2["POSITIVE"] + cats2["NEGATIVE"] == pytest.approx(1.0, 0.1) | ||||
| 
 | ||||
| 
 | ||||
| # fmt: off | ||||
| @pytest.mark.parametrize( | ||||
|     "textcat_config", | ||||
|     [ | ||||
|         {"@architectures": "spacy.TextCatBOW.v1", "exclusive_classes": False, "ngram_size": 1, "no_output_layer": False}, | ||||
|         {"@architectures": "spacy.TextCatBOW.v1", "exclusive_classes": True, "ngram_size": 4, "no_output_layer": False}, | ||||
|         {"@architectures": "spacy.TextCatBOW.v1", "exclusive_classes": False, "ngram_size": 3, "no_output_layer": True}, | ||||
|         {"@architectures": "spacy.TextCatBOW.v1", "exclusive_classes": True, "ngram_size": 2, "no_output_layer": True}, | ||||
|         {"@architectures": "spacy.TextCat.v1", "exclusive_classes": False, "ngram_size": 1, "pretrained_vectors": False, "width": 64, "conv_depth": 2, "embed_size": 2000, "window_size": 2}, | ||||
|         {"@architectures": "spacy.TextCat.v1", "exclusive_classes": True, "ngram_size": 5, "pretrained_vectors": False, "width": 128, "conv_depth": 2, "embed_size": 2000, "window_size": 1}, | ||||
|         {"@architectures": "spacy.TextCat.v1", "exclusive_classes": True, "ngram_size": 2, "pretrained_vectors": False, "width": 32, "conv_depth": 3, "embed_size": 500, "window_size": 3}, | ||||
|         {"@architectures": "spacy.TextCatCNN.v1", "tok2vec": default_tok2vec(), "exclusive_classes": True}, | ||||
|         {"@architectures": "spacy.TextCatCNN.v1", "tok2vec": default_tok2vec(), "exclusive_classes": False}, | ||||
|     ], | ||||
| ) | ||||
| # fmt: on | ||||
| def test_textcat_configs(textcat_config): | ||||
|     pipe_config = {"model": textcat_config} | ||||
|     nlp = English() | ||||
|     textcat = nlp.create_pipe("textcat", pipe_config) | ||||
|     for _, annotations in TRAIN_DATA: | ||||
|         for label, value in annotations.get("cats").items(): | ||||
|             textcat.add_label(label) | ||||
|     nlp.add_pipe(textcat) | ||||
|     optimizer = nlp.begin_training() | ||||
|     for i in range(5): | ||||
|         losses = {} | ||||
|         nlp.update(TRAIN_DATA, sgd=optimizer, losses=losses) | ||||
|  |  | |||
|  | @ -4,8 +4,7 @@ import ctypes | |||
| from pathlib import Path | ||||
| from spacy import util | ||||
| from spacy import prefer_gpu, require_gpu | ||||
| from spacy.ml._layers import PrecomputableAffine | ||||
| from spacy.ml._layers import _backprop_precomputable_affine_padding | ||||
| from spacy.ml._precomputable_affine import PrecomputableAffine, _backprop_precomputable_affine_padding | ||||
| 
 | ||||
| 
 | ||||
| @pytest.fixture | ||||
|  |  | |||
|  | @ -4,18 +4,7 @@ from spacy.ml.models.tok2vec import build_Tok2Vec_model | |||
| from spacy.vocab import Vocab | ||||
| from spacy.tokens import Doc | ||||
| 
 | ||||
| 
 | ||||
| def get_batch(batch_size): | ||||
|     vocab = Vocab() | ||||
|     docs = [] | ||||
|     start = 0 | ||||
|     for size in range(1, batch_size + 1): | ||||
|         # Make the words numbers, so that they're distinct | ||||
|         # across the batch, and easy to track. | ||||
|         numbers = [str(i) for i in range(start, start + size)] | ||||
|         docs.append(Doc(vocab, words=numbers)) | ||||
|         start += size | ||||
|     return docs | ||||
| from .util import get_batch | ||||
| 
 | ||||
| 
 | ||||
| # This fails in Thinc v7.3.1. Need to push patch | ||||
|  | @ -75,7 +64,7 @@ def test_tok2vec_batch_sizes(batch_size, width, embed_size): | |||
| def test_tok2vec_configs(tok2vec_config): | ||||
|     docs = get_batch(3) | ||||
|     tok2vec = build_Tok2Vec_model(**tok2vec_config) | ||||
|     tok2vec.initialize() | ||||
|     tok2vec.initialize(docs) | ||||
|     vectors, backprop = tok2vec.begin_update(docs) | ||||
|     assert len(vectors) == len(docs) | ||||
|     assert vectors[0].shape == (len(docs[0]), tok2vec_config["width"]) | ||||
|  |  | |||
|  | @ -9,6 +9,8 @@ from spacy import Errors | |||
| from spacy.tokens import Doc, Span | ||||
| from spacy.attrs import POS, TAG, HEAD, DEP, LEMMA | ||||
| 
 | ||||
| from spacy.vocab import Vocab | ||||
| 
 | ||||
| 
 | ||||
| @contextlib.contextmanager | ||||
| def make_tempfile(mode="r"): | ||||
|  | @ -77,6 +79,19 @@ def get_doc( | |||
|     return doc | ||||
| 
 | ||||
| 
 | ||||
| def get_batch(batch_size): | ||||
|     vocab = Vocab() | ||||
|     docs = [] | ||||
|     start = 0 | ||||
|     for size in range(1, batch_size + 1): | ||||
|         # Make the words numbers, so that they're distinct | ||||
|         # across the batch, and easy to track. | ||||
|         numbers = [str(i) for i in range(start, start + size)] | ||||
|         docs.append(Doc(vocab, words=numbers)) | ||||
|         start += size | ||||
|     return docs | ||||
| 
 | ||||
| 
 | ||||
| def apply_transition_sequence(parser, doc, sequence): | ||||
|     """Perform a series of pre-specified transitions, to put the parser in a | ||||
|     desired state.""" | ||||
|  |  | |||