diff --git a/MANIFEST.in b/MANIFEST.in index 1947b9140..e6d25284f 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,5 +1,5 @@ recursive-include include *.h -recursive-include spacy *.txt *.pyx *.pxd +recursive-include spacy *.pyx *.pxd *.txt *.cfg include LICENSE include README.md include bin/spacy diff --git a/bin/ud/ud_train.py b/bin/ud/ud_train.py index bda22088d..aa5050f3a 100644 --- a/bin/ud/ud_train.py +++ b/bin/ud/ud_train.py @@ -386,8 +386,8 @@ def _load_pretrained_tok2vec(nlp, loc): weights_data = file_.read() loaded = [] for name, component in nlp.pipeline: - if hasattr(component, "model") and hasattr(component.model, "tok2vec"): - component.tok2vec.from_bytes(weights_data) + if hasattr(component, "model") and component.model.has_ref("tok2vec"): + component.get_ref("tok2vec").from_bytes(weights_data) loaded.append(name) return loaded diff --git a/bin/wiki_entity_linking/train_descriptions.py b/bin/wiki_entity_linking/train_descriptions.py index d98bba565..b0cfbb4c6 100644 --- a/bin/wiki_entity_linking/train_descriptions.py +++ b/bin/wiki_entity_linking/train_descriptions.py @@ -1,13 +1,9 @@ -# coding: utf-8 from random import shuffle import logging import numpy as np -from thinc.model import Model -from thinc.api import chain -from thinc.loss import CosineDistance -from thinc.layers import Linear +from thinc.api import Model, chain, CosineDistance, Linear from spacy.util import create_default_optimizer diff --git a/examples/experiments/ptb-joint-pos-dep/bilstm_tok2vec.cfg b/examples/experiments/ptb-joint-pos-dep/bilstm_tok2vec.cfg index 8cd150868..4f1a915c5 100644 --- a/examples/experiments/ptb-joint-pos-dep/bilstm_tok2vec.cfg +++ b/examples/experiments/ptb-joint-pos-dep/bilstm_tok2vec.cfg @@ -39,25 +39,27 @@ factory = "tagger" factory = "parser" [nlp.pipeline.tagger.model] -@architectures = "tagger_model.v1" +@architectures = "spacy.Tagger.v1" [nlp.pipeline.tagger.model.tok2vec] -@architectures = "tok2vec_tensors.v1" +@architectures = "spacy.Tok2VecTensors.v1" width = ${nlp.pipeline.tok2vec.model:width} [nlp.pipeline.parser.model] -@architectures = "transition_based_parser.v1" +@architectures = "spacy.TransitionBasedParser.v1" nr_feature_tokens = 8 hidden_width = 64 maxout_pieces = 3 [nlp.pipeline.parser.model.tok2vec] -@architectures = "tok2vec_tensors.v1" +@architectures = "spacy.Tok2VecTensors.v1" width = ${nlp.pipeline.tok2vec.model:width} [nlp.pipeline.tok2vec.model] -@architectures = "hash_embed_bilstm.v1" +@architectures = "spacy.HashEmbedBiLSTM.v1" pretrained_vectors = ${nlp:vectors} width = 96 depth = 4 embed_size = 2000 +subword_features = true +char_embed = false diff --git a/examples/experiments/ptb-joint-pos-dep/defaults.cfg b/examples/experiments/ptb-joint-pos-dep/defaults.cfg index 6735284a7..2ceaab0be 100644 --- a/examples/experiments/ptb-joint-pos-dep/defaults.cfg +++ b/examples/experiments/ptb-joint-pos-dep/defaults.cfg @@ -39,27 +39,28 @@ factory = "tagger" factory = "parser" [nlp.pipeline.tagger.model] -@architectures = "tagger_model.v1" +@architectures = "spacy.Tagger.v1" [nlp.pipeline.tagger.model.tok2vec] -@architectures = "tok2vec_tensors.v1" +@architectures = "spacy.Tok2VecTensors.v1" width = ${nlp.pipeline.tok2vec.model:width} [nlp.pipeline.parser.model] -@architectures = "transition_based_parser.v1" +@architectures = "spacy.TransitionBasedParser.v1" nr_feature_tokens = 8 hidden_width = 64 maxout_pieces = 3 [nlp.pipeline.parser.model.tok2vec] -@architectures = "tok2vec_tensors.v1" +@architectures = "spacy.Tok2VecTensors.v1" width = ${nlp.pipeline.tok2vec.model:width} [nlp.pipeline.tok2vec.model] -@architectures = "hash_embed_cnn.v1" +@architectures = "spacy.HashEmbedCNN.v1" pretrained_vectors = ${nlp:vectors} width = 96 depth = 4 window_size = 1 embed_size = 2000 maxout_pieces = 3 +subword_features = true diff --git a/examples/training/pretrain_textcat.py b/examples/training/pretrain_textcat.py index 85d36fd66..0aefec9ef 100644 --- a/examples/training/pretrain_textcat.py +++ b/examples/training/pretrain_textcat.py @@ -20,9 +20,9 @@ import random import ml_datasets import spacy -from spacy.util import minibatch, use_gpu, compounding +from spacy.util import minibatch from spacy.pipeline import TextCategorizer -from spacy.ml.tok2vec import Tok2Vec +from spacy.ml.models.tok2vec import build_Tok2Vec_model import numpy @@ -65,9 +65,7 @@ def prefer_gpu(): def build_textcat_model(tok2vec, nr_class, width): - from thinc.model import Model - from thinc.layers import Softmax, chain, reduce_mean - from thinc.layers import list2ragged + from thinc.api import Model, Softmax, chain, reduce_mean, list2ragged with Model.define_operators({">>": chain}): model = ( @@ -76,7 +74,7 @@ def build_textcat_model(tok2vec, nr_class, width): >> reduce_mean() >> Softmax(nr_class, width) ) - model.tok2vec = tok2vec + model.set_ref("tok2vec", tok2vec) return model @@ -97,8 +95,9 @@ def create_pipeline(width, embed_size, vectors_model): textcat = TextCategorizer( nlp.vocab, labels=["POSITIVE", "NEGATIVE"], + # TODO: replace with config version model=build_textcat_model( - Tok2Vec(width=width, embed_size=embed_size), 2, width + build_Tok2Vec_model(width=width, embed_size=embed_size), 2, width ), ) @@ -121,7 +120,7 @@ def train_tensorizer(nlp, texts, dropout, n_iter): def train_textcat(nlp, n_texts, n_iter=10): textcat = nlp.get_pipe("textcat") - tok2vec_weights = textcat.model.tok2vec.to_bytes() + tok2vec_weights = textcat.model.get_ref("tok2vec").to_bytes() (train_texts, train_cats), (dev_texts, dev_cats) = load_textcat_data(limit=n_texts) print( "Using {} examples ({} training, {} evaluation)".format( @@ -135,7 +134,7 @@ def train_textcat(nlp, n_texts, n_iter=10): other_pipes = [pipe for pipe in nlp.pipe_names if pipe not in pipe_exceptions] with nlp.disable_pipes(*other_pipes): # only train textcat optimizer = nlp.begin_training() - textcat.model.tok2vec.from_bytes(tok2vec_weights) + textcat.model.get_ref("tok2vec").from_bytes(tok2vec_weights) print("Training the model...") print("{:^5}\t{:^5}\t{:^5}\t{:^5}".format("LOSS", "P", "R", "F")) for i in range(n_iter): diff --git a/examples/training/train_textcat.py b/examples/training/train_textcat.py index 4d402e04d..50c852ac1 100644 --- a/examples/training/train_textcat.py +++ b/examples/training/train_textcat.py @@ -74,7 +74,7 @@ def main(model=None, output_dir=None, n_iter=20, n_texts=2000, init_tok2vec=None optimizer = nlp.begin_training() if init_tok2vec is not None: with init_tok2vec.open("rb") as file_: - textcat.model.tok2vec.from_bytes(file_.read()) + textcat.model.get_ref("tok2vec").from_bytes(file_.read()) print("Training the model...") print("{:^5}\t{:^5}\t{:^5}\t{:^5}".format("LOSS", "P", "R", "F")) batch_sizes = compounding(4.0, 32.0, 1.001) diff --git a/pyproject.toml b/pyproject.toml index 71e523c7c..ee28d5d42 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ requires = [ "cymem>=2.0.2,<2.1.0", "preshed>=3.0.2,<3.1.0", "murmurhash>=0.28.0,<1.1.0", - "thinc==8.0.0a0", + "thinc==8.0.0a1", "blis>=0.4.0,<0.5.0" ] build-backend = "setuptools.build_meta" diff --git a/requirements.txt b/requirements.txt index f3a7cc162..09998cdc9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ # Our libraries cymem>=2.0.2,<2.1.0 preshed>=3.0.2,<3.1.0 -thinc==8.0.0a0 +thinc==8.0.0a1 blis>=0.4.0,<0.5.0 ml_datasets>=0.1.1 murmurhash>=0.28.0,<1.1.0 diff --git a/setup.cfg b/setup.cfg index 980269c35..7b3a468b6 100644 --- a/setup.cfg +++ b/setup.cfg @@ -36,13 +36,13 @@ setup_requires = cymem>=2.0.2,<2.1.0 preshed>=3.0.2,<3.1.0 murmurhash>=0.28.0,<1.1.0 - thinc==8.0.0a0 + thinc==8.0.0a1 install_requires = # Our libraries murmurhash>=0.28.0,<1.1.0 cymem>=2.0.2,<2.1.0 preshed>=3.0.2,<3.1.0 - thinc==8.0.0a0 + thinc==8.0.0a1 blis>=0.4.0,<0.5.0 wasabi>=0.4.0,<1.1.0 srsly>=2.0.0,<3.0.0 diff --git a/spacy/cli/pretrain.py b/spacy/cli/pretrain.py index 690e3107d..95d549254 100644 --- a/spacy/cli/pretrain.py +++ b/spacy/cli/pretrain.py @@ -11,10 +11,10 @@ import srsly from ..gold import Example from ..errors import Errors +from ..ml.models.multi_task import build_masked_language_model from ..tokens import Doc from ..attrs import ID, HEAD -from ..ml.component_models import Tok2Vec -from ..ml.component_models import masked_language_model +from ..ml.models.tok2vec import build_Tok2Vec_model from .. import util from ..util import create_default_optimizer from .train import _load_pretrained_tok2vec @@ -108,14 +108,19 @@ def pretrain( pretrained_vectors = None if not use_vectors else nlp.vocab.vectors model = create_pretraining_model( nlp, - Tok2Vec( + # TODO: replace with config + build_Tok2Vec_model( width, embed_rows, conv_depth=conv_depth, pretrained_vectors=pretrained_vectors, bilstm_depth=bilstm_depth, # Requires PyTorch. Experimental. subword_features=not use_chars, # Set to False for Chinese etc - cnn_maxout_pieces=cnn_pieces, # If set to 1, use Mish activation. + maxout_pieces=cnn_pieces, # If set to 1, use Mish activation. + window_size=1, + char_embed=False, + nM=64, + nC=8 ), ) # Load in pretrained weights @@ -152,7 +157,7 @@ def pretrain( is_temp_str = ".temp" if is_temp else "" with model.use_params(optimizer.averages): with (output_dir / f"model{epoch}{is_temp_str}.bin").open("wb") as file_: - file_.write(model.tok2vec.to_bytes()) + file_.write(model.get_ref("tok2vec").to_bytes()) log = { "nr_word": tracker.nr_word, "loss": tracker.loss, @@ -284,7 +289,7 @@ def create_pretraining_model(nlp, tok2vec): # "tok2vec" has to be the same set of processes as what the components do. tok2vec = chain(tok2vec, list2array()) model = chain(tok2vec, output_layer) - model = masked_language_model(nlp.vocab, model) + model = build_masked_language_model(nlp.vocab, model) model.set_ref("tok2vec", tok2vec) model.set_ref("output_layer", output_layer) model.initialize(X=[nlp.make_doc("Give it a doc to infer shapes")]) diff --git a/spacy/cli/train.py b/spacy/cli/train.py index 92f94b53d..5667bb905 100644 --- a/spacy/cli/train.py +++ b/spacy/cli/train.py @@ -9,7 +9,7 @@ from wasabi import msg import contextlib import random -from ..util import create_default_optimizer +from ..util import create_default_optimizer, registry from ..util import use_gpu as set_gpu from ..attrs import PROB, IS_OOV, CLUSTER, LANG from ..gold import GoldCorpus @@ -111,6 +111,8 @@ def train( eval_beam_widths.sort() has_beam_widths = eval_beam_widths != [1] + default_dir = Path(__file__).parent.parent / "ml" / "models" / "defaults" + # Set up the base model and pipeline. If a base model is specified, load # the model and make sure the pipeline matches the pipeline setting. If # training starts from a blank model, intitalize the language class. @@ -118,7 +120,6 @@ def train( msg.text(f"Training pipeline: {pipeline}") disabled_pipes = None pipes_added = False - msg.text(f"Training pipeline: {pipeline}") if use_gpu >= 0: activated_gpu = None try: @@ -140,16 +141,36 @@ def train( f"specified as `lang` argument ('{lang}') ", exits=1, ) + if vectors: + msg.text(f"Loading vectors from model '{vectors}'") + + nlp.disable_pipes([p for p in nlp.pipe_names if p not in pipeline]) for pipe in pipeline: - pipe_cfg = {} + # first, create the model. + # Bit of a hack after the refactor to get the vectors into a default config + # use train-from-config instead :-) if pipe == "parser": - pipe_cfg = {"learn_tokens": learn_tokens} + config_loc = default_dir / "parser_defaults.cfg" + elif pipe == "tagger": + config_loc = default_dir / "tagger_defaults.cfg" + elif pipe == "ner": + config_loc = default_dir / "ner_defaults.cfg" elif pipe == "textcat": - pipe_cfg = { - "exclusive_classes": not textcat_multilabel, - "architecture": textcat_arch, - "positive_label": textcat_positive_label, - } + config_loc = default_dir / "textcat_defaults.cfg" + else: + raise ValueError(f"Component {pipe} currently not supported.") + pipe_cfg = util.load_config(config_loc, create_objects=False) + if vectors: + pretrained_config = {'@architectures': 'spacy.VocabVectors.v1', 'name': vectors} + pipe_cfg["model"]["tok2vec"]["pretrained_vectors"] = pretrained_config + + if pipe == "parser": + pipe_cfg["learn_tokens"] = learn_tokens + elif pipe == "textcat": + pipe_cfg["exclusive_classes"] = not textcat_multilabel + pipe_cfg["architecture"] = textcat_arch + pipe_cfg["positive_label"] = textcat_positive_label + if pipe not in nlp.pipe_names: msg.text(f"Adding component to base model '{pipe}'") nlp.add_pipe(nlp.create_pipe(pipe, config=pipe_cfg)) @@ -181,26 +202,42 @@ def train( msg.text(f"Starting with blank model '{lang}'") lang_cls = util.get_lang_class(lang) nlp = lang_cls() + + if vectors: + msg.text(f"Loading vectors from model '{vectors}'") + for pipe in pipeline: + # first, create the model. + # Bit of a hack after the refactor to get the vectors into a default config + # use train-from-config instead :-) if pipe == "parser": - pipe_cfg = {"learn_tokens": learn_tokens} + config_loc = default_dir / "parser_defaults.cfg" + elif pipe == "tagger": + config_loc = default_dir / "tagger_defaults.cfg" + elif pipe == "ner": + config_loc = default_dir / "ner_defaults.cfg" elif pipe == "textcat": - pipe_cfg = { - "exclusive_classes": not textcat_multilabel, - "architecture": textcat_arch, - "positive_label": textcat_positive_label, - } + config_loc = default_dir / "textcat_defaults.cfg" else: - pipe_cfg = {} - nlp.add_pipe(nlp.create_pipe(pipe, config=pipe_cfg)) + raise ValueError(f"Component {pipe} currently not supported.") + pipe_cfg = util.load_config(config_loc, create_objects=False) + if vectors: + pretrained_config = {'@architectures': 'spacy.VocabVectors.v1', 'name': vectors} + pipe_cfg["model"]["tok2vec"]["pretrained_vectors"] = pretrained_config + + if pipe == "parser": + pipe_cfg["learn_tokens"] = learn_tokens + elif pipe == "textcat": + pipe_cfg["exclusive_classes"] = not textcat_multilabel + pipe_cfg["architecture"] = textcat_arch + pipe_cfg["positive_label"] = textcat_positive_label + + pipe = nlp.create_pipe(pipe, config=pipe_cfg) + nlp.add_pipe(pipe) # Update tag map with provided mapping nlp.vocab.morphology.tag_map.update(tag_map) - if vectors: - msg.text(f"Loading vector from model '{vectors}'") - _load_vectors(nlp, vectors) - # Multitask objectives multitask_options = [("parser", parser_multitasks), ("ner", entity_multitasks)] for pipe_name, multitasks in multitask_options: @@ -228,7 +265,7 @@ def train( optimizer = nlp.begin_training(lambda: corpus.train_examples, **cfg) nlp._optimizer = None - # Load in pretrained weights + # Load in pretrained weights (TODO: this may be broken in the config rewrite) if init_tok2vec is not None: components = _load_pretrained_tok2vec(nlp, init_tok2vec) msg.text(f"Loaded pretrained tok2vec for: {components}") @@ -531,7 +568,7 @@ def _create_progress_bar(total): def _load_vectors(nlp, vectors): - util.load_model(vectors, vocab=nlp.vocab) + loaded_model = util.load_model(vectors, vocab=nlp.vocab) for lex in nlp.vocab: values = {} for attr, func in nlp.vocab.lex_attr_getters.items(): @@ -541,6 +578,7 @@ def _load_vectors(nlp, vectors): values[lex.vocab.strings[attr]] = func(lex.orth_) lex.set_attrs(**values) lex.is_oov = False + return loaded_model def _load_pretrained_tok2vec(nlp, loc): @@ -551,8 +589,8 @@ def _load_pretrained_tok2vec(nlp, loc): weights_data = file_.read() loaded = [] for name, component in nlp.pipeline: - if hasattr(component, "model") and hasattr(component.model, "tok2vec"): - component.tok2vec.from_bytes(weights_data) + if hasattr(component, "model") and component.model.has_ref("tok2vec"): + component.get_ref("tok2vec").from_bytes(weights_data) loaded.append(name) return loaded diff --git a/spacy/cli/train_from_config.py b/spacy/cli/train_from_config.py index 9150da356..0dba8a962 100644 --- a/spacy/cli/train_from_config.py +++ b/spacy/cli/train_from_config.py @@ -1,19 +1,17 @@ from typing import Optional, Dict, List, Union, Sequence +from pydantic import BaseModel, FilePath, StrictInt + import plac -from wasabi import msg +import tqdm from pathlib import Path + +from wasabi import msg import thinc import thinc.schedules from thinc.api import Model -from pydantic import BaseModel, FilePath, StrictInt -import tqdm -# TODO: relative imports? -import spacy -from spacy.gold import GoldCorpus -from spacy.pipeline.tok2vec import Tok2VecListener -from spacy.ml import component_models -from spacy import util +from ..gold import GoldCorpus +from .. import util registry = util.registry @@ -57,23 +55,24 @@ factory = "tok2vec" factory = "ner" [nlp.pipeline.ner.model] -@architectures = "transition_based_ner.v1" +@architectures = "spacy.TransitionBasedParser.v1" nr_feature_tokens = 3 hidden_width = 64 maxout_pieces = 3 [nlp.pipeline.ner.model.tok2vec] -@architectures = "tok2vec_tensors.v1" +@architectures = "spacy.Tok2VecTensors.v1" width = ${nlp.pipeline.tok2vec.model:width} [nlp.pipeline.tok2vec.model] -@architectures = "hash_embed_cnn.v1" +@architectures = "spacy.HashEmbedCNN.v1" pretrained_vectors = ${nlp:vectors} width = 128 depth = 4 window_size = 1 embed_size = 10000 maxout_pieces = 3 +subword_features = true """ @@ -113,65 +112,6 @@ class ConfigSchema(BaseModel): extra = "allow" -# Of course, these would normally decorate the functions where they're defined. -# But for now... -@registry.architectures.register("hash_embed_cnn.v1") -def hash_embed_cnn( - pretrained_vectors, width, depth, embed_size, maxout_pieces, window_size -): - return component_models.Tok2Vec( - width=width, - embed_size=embed_size, - pretrained_vectors=pretrained_vectors, - conv_depth=depth, - cnn_maxout_pieces=maxout_pieces, - bilstm_depth=0, - window_size=window_size, - ) - - -@registry.architectures.register("hash_embed_bilstm.v1") -def hash_embed_bilstm_v1(pretrained_vectors, width, depth, embed_size): - return component_models.Tok2Vec( - width=width, - embed_size=embed_size, - pretrained_vectors=pretrained_vectors, - bilstm_depth=depth, - conv_depth=0, - cnn_maxout_pieces=0, - ) - - -@registry.architectures.register("tagger_model.v1") -def build_tagger_model_v1(tok2vec): - return component_models.build_tagger_model(nr_class=None, tok2vec=tok2vec) - - -@registry.architectures.register("transition_based_parser.v1") -def create_tb_parser_model( - tok2vec: Model, - nr_feature_tokens: StrictInt = 3, - hidden_width: StrictInt = 64, - maxout_pieces: StrictInt = 3, -): - from thinc.api import Linear, chain, list2array, use_ops, zero_init - from spacy.ml._layers import PrecomputableAffine - from spacy.syntax._parser_model import ParserModel - - token_vector_width = tok2vec.get_dim("nO") - tok2vec = chain(tok2vec, list2array()) - tok2vec.set_dim("nO", token_vector_width) - - lower = PrecomputableAffine( - hidden_width, nF=nr_feature_tokens, nI=tok2vec.get_dim("nO"), nP=maxout_pieces - ) - lower.set_dim("nP", maxout_pieces) - with use_ops("numpy"): - # Initialize weights at zero, as it's a classification layer. - upper = Linear(init_W=zero_init) - return ParserModel(tok2vec, lower, upper) - - @plac.annotations( # fmt: off train_path=("Location of JSON-formatted training data", "positional", None, Path), @@ -224,23 +164,25 @@ def train_from_config( config_path, data_paths, raw_text=None, meta_path=None, output_path=None, ): msg.info(f"Loading config from: {config_path}") - config = util.load_from_config(config_path, create_objects=True) + config = util.load_config(config_path, create_objects=True) use_gpu = config["training"]["use_gpu"] if use_gpu >= 0: msg.info("Using GPU") else: msg.info("Using CPU") msg.info("Creating nlp from config") - nlp = create_nlp_from_config(**config["nlp"]) + nlp_config = util.load_config(config_path, create_objects=False)["nlp"] + nlp = util.load_model_from_config(nlp_config) optimizer = config["optimizer"] - limit = config["training"]["limit"] + training = config["training"] + limit = training["limit"] msg.info("Loading training corpus") corpus = GoldCorpus(data_paths["train"], data_paths["dev"], limit=limit) msg.info("Initializing the nlp pipeline") nlp.begin_training(lambda: corpus.train_examples, device=use_gpu) - train_batches = create_train_batches(nlp, corpus, config["training"]) - evaluate = create_evaluation_callback(nlp, optimizer, corpus, config["training"]) + train_batches = create_train_batches(nlp, corpus, training) + evaluate = create_evaluation_callback(nlp, optimizer, corpus, training) # Create iterator, which yields out info after each optimization step. msg.info("Start training") @@ -249,16 +191,16 @@ def train_from_config( optimizer, train_batches, evaluate, - config["training"]["dropout"], - config["training"]["patience"], - config["training"]["eval_frequency"], + training["dropout"], + training["patience"], + training["eval_frequency"], ) msg.info(f"Training. Initial learn rate: {optimizer.learn_rate}") - print_row = setup_printer(config) + print_row = setup_printer(training, nlp) try: - progress = tqdm.tqdm(total=config["training"]["eval_frequency"], leave=False) + progress = tqdm.tqdm(total=training["eval_frequency"], leave=False) for batch, info, is_best_checkpoint in training_step_iterator: progress.update(1) if is_best_checkpoint is not None: @@ -266,9 +208,7 @@ def train_from_config( print_row(info) if is_best_checkpoint and output_path is not None: nlp.to_disk(output_path) - progress = tqdm.tqdm( - total=config["training"]["eval_frequency"], leave=False - ) + progress = tqdm.tqdm(total=training["eval_frequency"], leave=False) finally: if output_path is not None: with nlp.use_params(optimizer.averages): @@ -280,18 +220,6 @@ def train_from_config( # msg.good("Created best model", best_model_path) -def create_nlp_from_config(lang, vectors, pipeline): - lang_class = spacy.util.get_lang_class(lang) - nlp = lang_class() - if vectors is not None: - spacy.cli.train._load_vectors(nlp, vectors) - for name, component_cfg in pipeline.items(): - factory = component_cfg.pop("factory") - component = nlp.create_pipe(factory, config=component_cfg) - nlp.add_pipe(component, name=name) - return nlp - - def create_train_batches(nlp, corpus, cfg): while True: train_examples = corpus.train_dataset( @@ -405,10 +333,10 @@ def subdivide_batch(batch): return [batch] -def setup_printer(config): - score_cols = config["training"]["scores"] +def setup_printer(training, nlp): + score_cols = training["scores"] score_widths = [max(len(col), 6) for col in score_cols] - loss_cols = [f"Loss {pipe}" for pipe in config["nlp"]["pipeline"]] + loss_cols = [f"Loss {pipe}" for pipe in nlp.pipe_names] loss_widths = [max(len(col), 8) for col in loss_cols] table_header = ["#"] + loss_cols + score_cols + ["Score"] table_header = [col.upper() for col in table_header] @@ -420,20 +348,13 @@ def setup_printer(config): def print_row(info): losses = [ - "{0:.2f}".format(info["losses"].get(col, 0.0)) - for col in config["nlp"]["pipeline"] + "{0:.2f}".format(info["losses"].get(pipe_name, 0.0)) + for pipe_name in nlp.pipe_names ] scores = [ - "{0:.2f}".format(info["other_scores"].get(col, 0.0)) - for col in config["training"]["scores"] + "{0:.2f}".format(info["other_scores"].get(col, 0.0)) for col in score_cols ] data = [info["step"]] + losses + scores + ["{0:.2f}".format(info["score"])] msg.row(data, widths=table_widths, aligns=table_aligns) return print_row - - -@registry.architectures.register("tok2vec_tensors.v1") -def tok2vec_tensors_v1(width): - tok2vec = Tok2VecListener("tok2vec", width=width) - return tok2vec diff --git a/spacy/errors.py b/spacy/errors.py index 7a4953cce..6afbfc3c6 100644 --- a/spacy/errors.py +++ b/spacy/errors.py @@ -106,6 +106,12 @@ class Warnings(object): "Provide features as a dict {{\"Field1\": \"Value1,Value2\"}} or " "string \"Field1=Value1,Value2|Field2=Value3\".") + # TODO: fix numbering after merging develop into master + W098 = ("No Model config was provided to create the '{name}' component, " + "so a default configuration was used.") + W099 = ("Expected 'dict' type for the 'model' argument of pipe '{pipe}', " + "but got '{type}' instead, so ignoring it.") + @add_codes class Errors(object): @@ -227,7 +233,7 @@ class Errors(object): E050 = ("Can't find model '{name}'. It doesn't seem to be a Python " "package or a valid path to a data directory.") E052 = ("Can't find model directory: {path}") - E053 = ("Could not read meta.json from {path}") + E053 = ("Could not read {name} from {path}") E054 = ("No valid '{setting}' setting found in model meta.json.") E055 = ("Invalid ORTH value in exception:\nKey: {key}\nOrths: {orths}") E056 = ("Invalid tokenizer exception: ORTH values combined don't match " @@ -345,8 +351,8 @@ class Errors(object): E108 = ("As of spaCy v2.1, the pipe name `sbd` has been deprecated " "in favor of the pipe name `sentencizer`, which does the same " "thing. For example, use `nlp.create_pipeline('sentencizer')`") - E109 = ("Model for component '{name}' not initialized. Did you forget to " - "load a model, or forget to call begin_training()?") + E109 = ("Component '{name}' could not be run. Did you forget to " + "call begin_training()?") E110 = ("Invalid displaCy render wrapper. Expected callable, got: {obj}") E111 = ("Pickling a token is not supported, because tokens are only views " "of the parent Doc and can't exist on their own. A pickled token " @@ -532,6 +538,9 @@ class Errors(object): "make sure the gold EL data refers to valid results of the " "named entity recognizer in the `nlp` pipeline.") # TODO: fix numbering after merging develop into master + E993 = ("The config for 'nlp' should include either a key 'name' to " + "refer to an existing model by name or path, or a key 'lang' " + "to create a new blank model.") E996 = ("Could not parse {file}: {msg}") E997 = ("Tokenizer special cases are not allowed to modify the text. " "This would map '{chunk}' to '{orth}' given token attributes " diff --git a/spacy/language.py b/spacy/language.py index 1c6014cec..83f8c9d21 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -4,7 +4,9 @@ import weakref import functools from contextlib import contextmanager from copy import copy, deepcopy -from thinc.api import get_current_ops +from pathlib import Path + +from thinc.api import get_current_ops, Config import srsly import multiprocessing as mp from itertools import chain, cycle @@ -16,7 +18,7 @@ from .lookups import Lookups from .analysis import analyze_pipes, analyze_all_pipes, validate_attrs from .gold import Example from .scorer import Scorer -from .util import link_vectors_to_models, create_default_optimizer +from .util import link_vectors_to_models, create_default_optimizer, registry from .attrs import IS_STOP, LANG from .lang.punctuation import TOKENIZER_PREFIXES, TOKENIZER_SUFFIXES from .lang.punctuation import TOKENIZER_INFIXES @@ -24,7 +26,7 @@ from .lang.tokenizer_exceptions import TOKEN_MATCH from .lang.tag_map import TAG_MAP from .tokens import Doc from .lang.lex_attrs import LEX_ATTRS, is_stop -from .errors import Errors, Warnings, deprecation_warning +from .errors import Errors, Warnings, deprecation_warning, user_warning from . import util from . import about @@ -128,7 +130,7 @@ class Language(object): factories = {"tokenizer": lambda nlp: nlp.Defaults.create_tokenizer(nlp)} def __init__( - self, vocab=True, make_doc=True, max_length=10 ** 6, meta={}, **kwargs + self, vocab=True, make_doc=True, max_length=10 ** 6, meta={}, config=None, **kwargs ): """Initialise a Language object. @@ -138,6 +140,7 @@ class Language(object): object. Usually a `Tokenizer`. meta (dict): Custom meta data for the Language class. Is written to by models to add model meta data. + config (Config): Configuration data for creating the pipeline components. max_length (int) : Maximum number of characters in a single text. The current v2 models may run out memory on extremely long texts, due to large internal @@ -152,6 +155,9 @@ class Language(object): user_factories = util.registry.factories.get_all() self.factories.update(user_factories) self._meta = dict(meta) + self._config = config + if not self._config: + self._config = Config() self._path = None if vocab is True: factory = self.Defaults.create_vocab @@ -170,6 +176,21 @@ class Language(object): self.max_length = max_length self._optimizer = None + from .ml.models.defaults import default_tagger_config, default_parser_config, default_ner_config, \ + default_textcat_config, default_nel_config, default_morphologizer_config, default_sentrec_config, \ + default_tensorizer_config, default_tok2vec_config + + self.defaults = {"tagger": default_tagger_config(), + "parser": default_parser_config(), + "ner": default_ner_config(), + "textcat": default_textcat_config(), + "entity_linker": default_nel_config(), + "morphologizer": default_morphologizer_config(), + "sentrec": default_sentrec_config(), + "tensorizer": default_tensorizer_config(), + "tok2vec": default_tok2vec_config(), + } + @property def path(self): return self._path @@ -203,6 +224,10 @@ class Language(object): def meta(self, value): self._meta = value + @property + def config(self): + return self._config + # Conveniences to access pipeline components # Shouldn't be used anymore! @property @@ -293,7 +318,24 @@ class Language(object): else: raise KeyError(Errors.E002.format(name=name)) factory = self.factories[name] - return factory(self, **config) + default_config = self.defaults.get(name, None) + + # transform the model's config to an actual Model + model_cfg = None + if "model" in config: + model_cfg = config["model"] + if not isinstance(model_cfg, dict): + user_warning(Warnings.W099.format(type=type(model_cfg), pipe=name)) + model_cfg = None + del config["model"] + if model_cfg is None and default_config is not None: + user_warning(Warnings.W098) + model_cfg = default_config["model"] + model = None + if model_cfg is not None: + self.config[name] = {"model": model_cfg} + model = registry.make_from_config({"model": model_cfg}, validate=True)["model"] + return factory(self, model, **config) def add_pipe( self, component, name=None, before=None, after=None, first=None, last=None @@ -430,7 +472,10 @@ class Language(object): continue if not hasattr(proc, "__call__"): raise ValueError(Errors.E003.format(component=type(proc), name=name)) - doc = proc(doc, **component_cfg.get(name, {})) + try: + doc = proc(doc, **component_cfg.get(name, {})) + except KeyError: + raise ValueError(Errors.E109.format(name=name)) if doc is None: raise ValueError(Errors.E005.format(name=name)) return doc @@ -578,9 +623,6 @@ class Language(object): ops = get_current_ops() self.vocab.vectors.data = ops.asarray(self.vocab.vectors.data) link_vectors_to_models(self.vocab) - if self.vocab.vectors.data.shape[1]: - cfg["pretrained_vectors"] = self.vocab.vectors.name - cfg["pretrained_dims"] = self.vocab.vectors.data.shape[1] if sgd is None: sgd = create_default_optimizer() self._optimizer = sgd @@ -611,8 +653,6 @@ class Language(object): if self.vocab.vectors.data.shape[1] >= 1: self.vocab.vectors.data = ops.asarray(self.vocab.vectors.data) link_vectors_to_models(self.vocab) - if self.vocab.vectors.data.shape[1]: - cfg["pretrained_vectors"] = self.vocab.vectors if sgd is None: sgd = create_default_optimizer() self._optimizer = sgd @@ -868,6 +908,7 @@ class Language(object): serializers["meta.json"] = lambda p: p.open("w").write( srsly.json_dumps(self.meta) ) + serializers["config.cfg"] = lambda p: self.config.to_disk(p) for name, proc in self.pipeline: if not hasattr(proc, "name"): continue @@ -895,6 +936,8 @@ class Language(object): exclude = disable path = util.ensure_path(path) deserializers = {} + if Path(path / "config.cfg").exists(): + deserializers["config.cfg"] = lambda p: self.config.from_disk(p) deserializers["meta.json"] = lambda p: self.meta.update(srsly.read_json(p)) deserializers["vocab"] = lambda p: self.vocab.from_disk( p @@ -933,6 +976,7 @@ class Language(object): serializers["vocab"] = lambda: self.vocab.to_bytes() serializers["tokenizer"] = lambda: self.tokenizer.to_bytes(exclude=["vocab"]) serializers["meta.json"] = lambda: srsly.json_dumps(self.meta) + serializers["config.cfg"] = lambda: self.config.to_bytes() for name, proc in self.pipeline: if name in exclude: continue @@ -955,6 +999,7 @@ class Language(object): deprecation_warning(Warnings.W014) exclude = disable deserializers = {} + deserializers["config.cfg"] = lambda b: self.config.from_bytes(b) deserializers["meta.json"] = lambda b: self.meta.update(srsly.json_loads(b)) deserializers["vocab"] = lambda b: self.vocab.from_bytes( b @@ -981,8 +1026,8 @@ class component(object): and class components and will automatically register components in the Language.factories. If the component is a class and needs access to the nlp object or config parameters, it can expose a from_nlp classmethod - that takes the nlp object and **cfg arguments and returns the initialized - component. + that takes the nlp & model objects and **cfg arguments, and returns the + initialized component. """ # NB: This decorator needs to live here, because it needs to write to @@ -1011,9 +1056,9 @@ class component(object): obj.requires = self.requires obj.retokenizes = self.retokenizes - def factory(nlp, **cfg): + def factory(nlp, model, **cfg): if hasattr(obj, "from_nlp"): - return obj.from_nlp(nlp, **cfg) + return obj.from_nlp(nlp, model, **cfg) elif isinstance(obj, type): return obj() return obj diff --git a/spacy/ml/component_models.py b/spacy/ml/component_models.py deleted file mode 100644 index 8c694f950..000000000 --- a/spacy/ml/component_models.py +++ /dev/null @@ -1,227 +0,0 @@ -from spacy import util -from spacy.ml.extract_ngrams import extract_ngrams - -from ..attrs import ID, ORTH, NORM, PREFIX, SUFFIX, SHAPE -from ..errors import Errors -from ._character_embed import CharacterEmbed - -from thinc.api import Model, Maxout, Linear, residual, reduce_mean, list2ragged -from thinc.api import PyTorchLSTM, add, MultiSoftmax, HashEmbed, StaticVectors -from thinc.api import expand_window, FeatureExtractor, SparseLinear, chain -from thinc.api import clone, concatenate, with_array, Softmax, Logistic, uniqued -from thinc.api import zero_init - - -def build_text_classifier(arch, config): - if arch == "cnn": - return build_simple_cnn_text_classifier(**config) - elif arch == "bow": - return build_bow_text_classifier(**config) - else: - raise ValueError("Unexpected textcat arch") - - -def build_simple_cnn_text_classifier(tok2vec, nr_class, exclusive_classes, **cfg): - """ - Build a simple CNN text classifier, given a token-to-vector model as inputs. - If exclusive_classes=True, a softmax non-linearity is applied, so that the - outputs sum to 1. If exclusive_classes=False, a logistic non-linearity - is applied instead, so that outputs are in the range [0, 1]. - """ - with Model.define_operators({">>": chain}): - if exclusive_classes: - output_layer = Softmax(nO=nr_class, nI=tok2vec.get_dim("nO")) - else: - # TODO: experiment with init_w=zero_init - output_layer = Linear(nO=nr_class, nI=tok2vec.get_dim("nO")) >> Logistic() - model = tok2vec >> list2ragged() >> reduce_mean() >> output_layer - model.set_ref("tok2vec", tok2vec) - model.set_dim("nO", nr_class) - return model - - -def build_bow_text_classifier( - nr_class, exclusive_classes, ngram_size=1, no_output_layer=False, **cfg -): - with Model.define_operators({">>": chain}): - model = extract_ngrams(ngram_size, attr=ORTH) >> SparseLinear(nr_class) - model.to_cpu() - if not no_output_layer: - output_layer = ( - Softmax(nO=nr_class) if exclusive_classes else Logistic(nO=nr_class) - ) - output_layer.to_cpu() - model = model >> output_layer - model.set_dim("nO", nr_class) - return model - - -def build_nel_encoder(embed_width, hidden_width, ner_types, **cfg): - if "entity_width" not in cfg: - raise ValueError(Errors.E144.format(param="entity_width")) - - conv_depth = cfg.get("conv_depth", 2) - cnn_maxout_pieces = cfg.get("cnn_maxout_pieces", 3) - pretrained_vectors = cfg.get("pretrained_vectors", None) - context_width = cfg.get("entity_width") - - with Model.define_operators({">>": chain, "**": clone}): - nel_tok2vec = Tok2Vec( - width=hidden_width, - embed_size=embed_width, - pretrained_vectors=pretrained_vectors, - cnn_maxout_pieces=cnn_maxout_pieces, - subword_features=True, - conv_depth=conv_depth, - bilstm_depth=0, - ) - - model = ( - nel_tok2vec - >> list2ragged() - >> reduce_mean() - >> residual(Maxout(nO=hidden_width, nI=hidden_width, nP=2, dropout=0.0)) - >> Linear(nO=context_width, nI=hidden_width) - ) - model.initialize() - - model.set_ref("tok2vec", nel_tok2vec) - model.set_dim("nO", context_width) - return model - - -def masked_language_model(*args, **kwargs): - raise NotImplementedError - - -def build_tagger_model(nr_class, tok2vec): - token_vector_width = tok2vec.get_dim("nO") - # TODO: glorot_uniform_init seems to work a bit better than zero_init here?! - softmax = with_array(Softmax(nO=nr_class, nI=token_vector_width, init_W=zero_init)) - model = chain(tok2vec, softmax) - model.set_ref("tok2vec", tok2vec) - model.set_ref("softmax", softmax) - return model - - -def build_morphologizer_model(class_nums, **cfg): - embed_size = util.env_opt("embed_size", 7000) - if "token_vector_width" in cfg: - token_vector_width = cfg["token_vector_width"] - else: - token_vector_width = util.env_opt("token_vector_width", 128) - pretrained_vectors = cfg.get("pretrained_vectors") - char_embed = cfg.get("char_embed", True) - with Model.define_operators({">>": chain, "+": add, "**": clone}): - if "tok2vec" in cfg: - tok2vec = cfg["tok2vec"] - else: - tok2vec = Tok2Vec( - token_vector_width, - embed_size, - char_embed=char_embed, - pretrained_vectors=pretrained_vectors, - ) - softmax = with_array(MultiSoftmax(nOs=class_nums, nI=token_vector_width)) - model = tok2vec >> softmax - model.set_ref("tok2vec", tok2vec) - model.set_ref("softmax", softmax) - return model - - -def Tok2Vec( - width, - embed_size, - pretrained_vectors=None, - window_size=1, - cnn_maxout_pieces=3, - subword_features=True, - char_embed=False, - conv_depth=4, - bilstm_depth=0, -): - if char_embed: - subword_features = False - cols = [ID, NORM, PREFIX, SUFFIX, SHAPE, ORTH] - with Model.define_operators({">>": chain, "|": concatenate, "**": clone}): - norm = HashEmbed(nO=width, nV=embed_size, column=cols.index(NORM), dropout=0.0) - if subword_features: - prefix = HashEmbed( - nO=width, nV=embed_size // 2, column=cols.index(PREFIX), dropout=0.0 - ) - suffix = HashEmbed( - nO=width, nV=embed_size // 2, column=cols.index(SUFFIX), dropout=0.0 - ) - shape = HashEmbed( - nO=width, nV=embed_size // 2, column=cols.index(SHAPE), dropout=0.0 - ) - else: - prefix, suffix, shape = (None, None, None) - if pretrained_vectors is not None: - glove = StaticVectors( - vectors=pretrained_vectors, nO=width, column=cols.index(ID), dropout=0.0 - ) - - if subword_features: - embed = uniqued( - (glove | norm | prefix | suffix | shape) - >> Maxout( - nO=width, nI=width * 5, nP=3, dropout=0.0, normalize=True - ), - column=cols.index(ORTH), - ) - else: - embed = uniqued( - (glove | norm) - >> Maxout( - nO=width, nI=width * 2, nP=3, dropout=0.0, normalize=True - ), - column=cols.index(ORTH), - ) - elif subword_features: - embed = uniqued( - concatenate(norm, prefix, suffix, shape) - >> Maxout(nO=width, nI=width * 4, nP=3, dropout=0.0, normalize=True), - column=cols.index(ORTH), - ) - elif char_embed: - embed = CharacterEmbed(nM=64, nC=8) | FeatureExtractor(cols) >> with_array( - norm - ) - reduce_dimensions = Maxout( - nO=width, - nI=64 * 8 + width, - nP=cnn_maxout_pieces, - dropout=0.0, - normalize=True, - ) - else: - embed = norm - - convolution = residual( - expand_window(window_size=window_size) - >> Maxout( - nO=width, - nI=width * 3, - nP=cnn_maxout_pieces, - dropout=0.0, - normalize=True, - ) - ) - if char_embed: - tok2vec = embed >> with_array( - reduce_dimensions >> convolution ** conv_depth, pad=conv_depth - ) - else: - tok2vec = FeatureExtractor(cols) >> with_array( - embed >> convolution ** conv_depth, pad=conv_depth - ) - - if bilstm_depth >= 1: - tok2vec = tok2vec >> PyTorchLSTM( - nO=width, nI=width, depth=bilstm_depth, bi=True - ) - # Work around thinc API limitations :(. TODO: Revise in Thinc 7 - tok2vec.set_dim("nO", width) - tok2vec.set_ref("embed", embed) - return tok2vec diff --git a/spacy/ml/models/__init__.py b/spacy/ml/models/__init__.py new file mode 100644 index 000000000..56696d581 --- /dev/null +++ b/spacy/ml/models/__init__.py @@ -0,0 +1,6 @@ +from .entity_linker import * +from .parser import * +from .tagger import * +from .tensorizer import * +from .textcat import * +from .tok2vec import * diff --git a/spacy/ml/models/defaults/__init__.py b/spacy/ml/models/defaults/__init__.py new file mode 100644 index 000000000..9af4da87d --- /dev/null +++ b/spacy/ml/models/defaults/__init__.py @@ -0,0 +1,93 @@ +from pathlib import Path + +from .... import util + + +def default_nel_config(): + loc = Path(__file__).parent / "entity_linker_defaults.cfg" + return util.load_config(loc, create_objects=False) + + +def default_nel(): + loc = Path(__file__).parent / "entity_linker_defaults.cfg" + return util.load_config(loc, create_objects=True)["model"] + + +def default_morphologizer_config(): + loc = Path(__file__).parent / "morphologizer_defaults.cfg" + return util.load_config(loc, create_objects=False) + + +def default_morphologizer(): + loc = Path(__file__).parent / "morphologizer_defaults.cfg" + return util.load_config(loc, create_objects=True)["model"] + + +def default_parser_config(): + loc = Path(__file__).parent / "parser_defaults.cfg" + return util.load_config(loc, create_objects=False) + + +def default_parser(): + loc = Path(__file__).parent / "parser_defaults.cfg" + return util.load_config(loc, create_objects=True)["model"] + + +def default_ner_config(): + loc = Path(__file__).parent / "ner_defaults.cfg" + return util.load_config(loc, create_objects=False) + + +def default_ner(): + loc = Path(__file__).parent / "ner_defaults.cfg" + return util.load_config(loc, create_objects=True)["model"] + + +def default_sentrec_config(): + loc = Path(__file__).parent / "sentrec_defaults.cfg" + return util.load_config(loc, create_objects=False) + + +def default_sentrec(): + loc = Path(__file__).parent / "sentrec_defaults.cfg" + return util.load_config(loc, create_objects=True)["model"] + + +def default_tagger_config(): + loc = Path(__file__).parent / "tagger_defaults.cfg" + return util.load_config(loc, create_objects=False) + + +def default_tagger(): + loc = Path(__file__).parent / "tagger_defaults.cfg" + return util.load_config(loc, create_objects=True)["model"] + + +def default_tensorizer_config(): + loc = Path(__file__).parent / "tensorizer_defaults.cfg" + return util.load_config(loc, create_objects=False) + + +def default_tensorizer(): + loc = Path(__file__).parent / "tensorizer_defaults.cfg" + return util.load_config(loc, create_objects=True)["model"] + + +def default_textcat_config(): + loc = Path(__file__).parent / "textcat_defaults.cfg" + return util.load_config(loc, create_objects=False) + + +def default_textcat(): + loc = Path(__file__).parent / "textcat_defaults.cfg" + return util.load_config(loc, create_objects=True)["model"] + + +def default_tok2vec_config(): + loc = Path(__file__).parent / "tok2vec_defaults.cfg" + return util.load_config(loc, create_objects=False) + + +def default_tok2vec(): + loc = Path(__file__).parent / "tok2vec_defaults.cfg" + return util.load_config(loc, create_objects=True)["model"] diff --git a/spacy/ml/models/defaults/entity_linker_defaults.cfg b/spacy/ml/models/defaults/entity_linker_defaults.cfg new file mode 100644 index 000000000..6a591ec3e --- /dev/null +++ b/spacy/ml/models/defaults/entity_linker_defaults.cfg @@ -0,0 +1,12 @@ +[model] +@architectures = "spacy.EntityLinker.v1" + +[model.tok2vec] +@architectures = "spacy.HashEmbedCNN.v1" +pretrained_vectors = null +width = 96 +depth = 2 +embed_size = 300 +window_size = 1 +maxout_pieces = 3 +subword_features = true diff --git a/spacy/ml/models/defaults/morphologizer_defaults.cfg b/spacy/ml/models/defaults/morphologizer_defaults.cfg new file mode 100644 index 000000000..80e776c4f --- /dev/null +++ b/spacy/ml/models/defaults/morphologizer_defaults.cfg @@ -0,0 +1,14 @@ +[model] +@architectures = "spacy.Tagger.v1" + +[model.tok2vec] +@architectures = "spacy.HashCharEmbedCNN.v1" +pretrained_vectors = null +width = 128 +depth = 4 +embed_size = 7000 +window_size = 1 +maxout_pieces = 3 +subword_features = true +nM = 64 +nC = 8 diff --git a/spacy/ml/models/defaults/ner_defaults.cfg b/spacy/ml/models/defaults/ner_defaults.cfg new file mode 100644 index 000000000..db2c131f5 --- /dev/null +++ b/spacy/ml/models/defaults/ner_defaults.cfg @@ -0,0 +1,15 @@ +[model] +@architectures = "spacy.TransitionBasedParser.v1" +nr_feature_tokens = 6 +hidden_width = 64 +maxout_pieces = 2 + +[model.tok2vec] +@architectures = "spacy.HashEmbedCNN.v1" +pretrained_vectors = null +width = 96 +depth = 4 +embed_size = 2000 +window_size = 1 +maxout_pieces = 3 +subword_features = true diff --git a/spacy/ml/models/defaults/parser_defaults.cfg b/spacy/ml/models/defaults/parser_defaults.cfg new file mode 100644 index 000000000..9cbb6eadb --- /dev/null +++ b/spacy/ml/models/defaults/parser_defaults.cfg @@ -0,0 +1,15 @@ +[model] +@architectures = "spacy.TransitionBasedParser.v1" +nr_feature_tokens = 8 +hidden_width = 64 +maxout_pieces = 2 + +[model.tok2vec] +@architectures = "spacy.HashEmbedCNN.v1" +pretrained_vectors = null +width = 96 +depth = 4 +embed_size = 2000 +window_size = 1 +maxout_pieces = 3 +subword_features = true diff --git a/spacy/ml/models/defaults/sentrec_defaults.cfg b/spacy/ml/models/defaults/sentrec_defaults.cfg new file mode 100644 index 000000000..a039a4533 --- /dev/null +++ b/spacy/ml/models/defaults/sentrec_defaults.cfg @@ -0,0 +1,14 @@ +[model] +@architectures = "spacy.Tagger.v1" + +[model.tok2vec] +@architectures = "spacy.HashCharEmbedCNN.v1" +pretrained_vectors = null +width = 12 +depth = 1 +embed_size = 2000 +window_size = 1 +maxout_pieces = 2 +subword_features = true +nM = 64 +nC = 8 diff --git a/spacy/ml/models/defaults/tagger_defaults.cfg b/spacy/ml/models/defaults/tagger_defaults.cfg new file mode 100644 index 000000000..5aea80a32 --- /dev/null +++ b/spacy/ml/models/defaults/tagger_defaults.cfg @@ -0,0 +1,12 @@ +[model] +@architectures = "spacy.Tagger.v1" + +[model.tok2vec] +@architectures = "spacy.HashEmbedCNN.v1" +pretrained_vectors = null +width = 96 +depth = 4 +embed_size = 2000 +window_size = 1 +maxout_pieces = 3 +subword_features = true diff --git a/spacy/ml/models/defaults/tensorizer_defaults.cfg b/spacy/ml/models/defaults/tensorizer_defaults.cfg new file mode 100644 index 000000000..81880a109 --- /dev/null +++ b/spacy/ml/models/defaults/tensorizer_defaults.cfg @@ -0,0 +1,4 @@ +[model] +@architectures = "spacy.Tensorizer.v1" +input_size=96 +output_size=300 diff --git a/spacy/ml/models/defaults/textcat_defaults.cfg b/spacy/ml/models/defaults/textcat_defaults.cfg new file mode 100644 index 000000000..cea1bfe54 --- /dev/null +++ b/spacy/ml/models/defaults/textcat_defaults.cfg @@ -0,0 +1,13 @@ +[model] +@architectures = "spacy.TextCatCNN.v1" +exclusive_classes = false + +[model.tok2vec] +@architectures = "spacy.HashEmbedCNN.v1" +pretrained_vectors = null +width = 96 +depth = 4 +embed_size = 2000 +window_size = 1 +maxout_pieces = 3 +subword_features = true diff --git a/spacy/ml/models/defaults/tok2vec_defaults.cfg b/spacy/ml/models/defaults/tok2vec_defaults.cfg new file mode 100644 index 000000000..9475d4aab --- /dev/null +++ b/spacy/ml/models/defaults/tok2vec_defaults.cfg @@ -0,0 +1,9 @@ +[model] +@architectures = "spacy.HashEmbedCNN.v1" +pretrained_vectors = null +width = 96 +depth = 4 +embed_size = 2000 +window_size = 1 +maxout_pieces = 3 +subword_features = true diff --git a/spacy/ml/models/entity_linker.py b/spacy/ml/models/entity_linker.py new file mode 100644 index 000000000..0c1762026 --- /dev/null +++ b/spacy/ml/models/entity_linker.py @@ -0,0 +1,23 @@ +from pathlib import Path + +from thinc.api import chain, clone, list2ragged, reduce_mean, residual +from thinc.api import Model, Maxout, Linear + +from spacy.util import registry + + +@registry.architectures.register("spacy.EntityLinker.v1") +def build_nel_encoder(tok2vec, nO=None): + with Model.define_operators({">>": chain, "**": clone}): + token_width = tok2vec.get_dim("nO") + output_layer = Linear(nO=nO, nI=token_width) + model = ( + tok2vec + >> list2ragged() + >> reduce_mean() + >> residual(Maxout(nO=token_width, nI=token_width, nP=2, dropout=0.0)) + >> output_layer + ) + model.set_ref("output_layer", output_layer) + model.set_ref("tok2vec", tok2vec) + return model diff --git a/spacy/ml/models/multi_task.py b/spacy/ml/models/multi_task.py new file mode 100644 index 000000000..1c193df82 --- /dev/null +++ b/spacy/ml/models/multi_task.py @@ -0,0 +1,29 @@ +from thinc.api import chain, Maxout, LayerNorm, Softmax, Linear, zero_init + + +def build_multi_task_model(n_tags, tok2vec=None, token_vector_width=96): + model = chain( + tok2vec, + Maxout(nO=token_vector_width * 2, nI=token_vector_width, nP=3, dropout=0.0), + LayerNorm(token_vector_width * 2), + Softmax(nO=n_tags, nI=token_vector_width * 2), + ) + return model + + +def build_cloze_multi_task_model(vocab, tok2vec): + output_size = vocab.vectors.data.shape[1] + output_layer = chain( + Maxout( + nO=output_size, nI=tok2vec.get_dim("nO"), nP=3, normalize=True, dropout=0.0 + ), + Linear(nO=output_size, nI=output_size, init_W=zero_init), + ) + model = chain(tok2vec, output_layer) + model = build_masked_language_model(vocab, model) + return model + + +def build_masked_language_model(*args, **kwargs): + # TODO cf https://github.com/explosion/spaCy/blob/2c107f02a4d60bda2440db0aad1a88cbbf4fb52d/spacy/_ml.py#L828 + raise NotImplementedError diff --git a/spacy/ml/models/parser.py b/spacy/ml/models/parser.py new file mode 100644 index 000000000..89f303e2a --- /dev/null +++ b/spacy/ml/models/parser.py @@ -0,0 +1,33 @@ +from pydantic import StrictInt + +from spacy.util import registry +from spacy.ml._layers import PrecomputableAffine +from spacy.syntax._parser_model import ParserModel + +from thinc.api import Model, chain, list2array, Linear, zero_init, use_ops + + +@registry.architectures.register("spacy.TransitionBasedParser.v1") +def build_tb_parser_model( + tok2vec: Model, + nr_feature_tokens: StrictInt, + hidden_width: StrictInt, + maxout_pieces: StrictInt, + nO=None, +): + token_vector_width = tok2vec.get_dim("nO") + tok2vec = chain(tok2vec, list2array()) + tok2vec.set_dim("nO", token_vector_width) + + lower = PrecomputableAffine( + nO=hidden_width, + nF=nr_feature_tokens, + nI=tok2vec.get_dim("nO"), + nP=maxout_pieces, + ) + lower.set_dim("nP", maxout_pieces) + with use_ops("numpy"): + # Initialize weights at zero, as it's a classification layer. + upper = Linear(nO=nO, init_W=zero_init) + model = ParserModel(tok2vec, lower, upper) + return model diff --git a/spacy/ml/models/tagger.py b/spacy/ml/models/tagger.py new file mode 100644 index 000000000..92e8be1b2 --- /dev/null +++ b/spacy/ml/models/tagger.py @@ -0,0 +1,16 @@ +from thinc.api import zero_init, with_array, Softmax, chain, Model + +from spacy.util import registry + + +@registry.architectures.register("spacy.Tagger.v1") +def build_tagger_model(tok2vec, nO=None) -> Model: + token_vector_width = tok2vec.get_dim("nO") + # TODO: glorot_uniform_init seems to work a bit better than zero_init here?! + output_layer = Softmax(nO, nI=token_vector_width, init_W=zero_init) + softmax = with_array(output_layer) + model = chain(tok2vec, softmax) + model.set_ref("tok2vec", tok2vec) + model.set_ref("softmax", softmax) + model.set_ref("output_layer", output_layer) + return model diff --git a/spacy/ml/models/tensorizer.py b/spacy/ml/models/tensorizer.py new file mode 100644 index 000000000..f66610b64 --- /dev/null +++ b/spacy/ml/models/tensorizer.py @@ -0,0 +1,10 @@ +from thinc.api import Linear, zero_init + +from ... import util +from ...util import registry + + +@registry.architectures.register("spacy.Tensorizer.v1") +def build_tensorizer(input_size, output_size): + input_size = util.env_opt("token_vector_width", input_size) + return Linear(output_size, input_size, init_W=zero_init) diff --git a/spacy/ml/models/textcat.py b/spacy/ml/models/textcat.py new file mode 100644 index 000000000..d9ac34b99 --- /dev/null +++ b/spacy/ml/models/textcat.py @@ -0,0 +1,42 @@ +from spacy.attrs import ORTH +from spacy.util import registry +from spacy.ml.extract_ngrams import extract_ngrams + +from thinc.api import Model, chain, reduce_mean, Linear, list2ragged, Logistic, SparseLinear, Softmax + + +@registry.architectures.register("spacy.TextCatCNN.v1") +def build_simple_cnn_text_classifier(tok2vec, exclusive_classes, nO=None): + """ + Build a simple CNN text classifier, given a token-to-vector model as inputs. + If exclusive_classes=True, a softmax non-linearity is applied, so that the + outputs sum to 1. If exclusive_classes=False, a logistic non-linearity + is applied instead, so that outputs are in the range [0, 1]. + """ + with Model.define_operators({">>": chain}): + if exclusive_classes: + output_layer = Softmax(nO=nO, nI=tok2vec.get_dim("nO")) + model = tok2vec >> list2ragged() >> reduce_mean() >> output_layer + model.set_ref("output_layer", output_layer) + else: + # TODO: experiment with init_w=zero_init + linear_layer = Linear(nO=nO, nI=tok2vec.get_dim("nO")) + model = tok2vec >> list2ragged() >> reduce_mean() >> linear_layer >> Logistic() + model.set_ref("output_layer", linear_layer) + model.set_ref("tok2vec", tok2vec) + model.set_dim("nO", nO) + return model + + +@registry.architectures.register("spacy.TextCatBOW.v1") +def build_bow_text_classifier(exclusive_classes, ngram_size, no_output_layer, nO=None): + # Note: original defaults were ngram_size=1 and no_output_layer=False + with Model.define_operators({">>": chain}): + model = extract_ngrams(ngram_size, attr=ORTH) >> SparseLinear(nO) + model.to_cpu() + if not no_output_layer: + output_layer = Softmax(nO) if exclusive_classes else Logistic(nO) + output_layer.to_cpu() + model = model >> output_layer + model.set_ref("output_layer", output_layer) + return model diff --git a/spacy/ml/models/tok2vec.py b/spacy/ml/models/tok2vec.py new file mode 100644 index 000000000..2e0e4c2d4 --- /dev/null +++ b/spacy/ml/models/tok2vec.py @@ -0,0 +1,390 @@ +from thinc.api import chain, clone, concatenate, with_array, uniqued +from thinc.api import Model, noop, with_padded, Maxout, expand_window +from thinc.api import HashEmbed, StaticVectors, PyTorchLSTM +from thinc.api import residual, LayerNorm, FeatureExtractor, Mish + +from ... import util +from ...util import registry, make_layer +from ...ml import _character_embed +from ...pipeline.tok2vec import Tok2VecListener +from ...attrs import ID, ORTH, NORM, PREFIX, SUFFIX, SHAPE + + +@registry.architectures.register("spacy.Tok2VecTensors.v1") +def tok2vec_tensors_v1(width): + tok2vec = Tok2VecListener("tok2vec", width=width) + return tok2vec + + +@registry.architectures.register("spacy.VocabVectors.v1") +def get_vocab_vectors(name): + nlp = util.load_model(name) + return nlp.vocab.vectors + + +@registry.architectures.register("spacy.Tok2Vec.v1") +def Tok2Vec(config): + doc2feats = make_layer(config["@doc2feats"]) + embed = make_layer(config["@embed"]) + encode = make_layer(config["@encode"]) + field_size = 0 + if encode.has_attr("receptive_field"): + field_size = encode.attrs["receptive_field"] + tok2vec = chain(doc2feats, with_array(chain(embed, encode), pad=field_size)) + tok2vec.attrs["cfg"] = config + tok2vec.set_dim("nO", encode.get_dim("nO")) + tok2vec.set_ref("embed", embed) + tok2vec.set_ref("encode", encode) + return tok2vec + + +@registry.architectures.register("spacy.Doc2Feats.v1") +def Doc2Feats(config): + columns = config["columns"] + return FeatureExtractor(columns) + + +@registry.architectures.register("spacy.HashEmbedCNN.v1") +def hash_embed_cnn( + pretrained_vectors, + width, + depth, + embed_size, + maxout_pieces, + window_size, + subword_features, +): + # Does not use character embeddings: set to False by default + return build_Tok2Vec_model( + width=width, + embed_size=embed_size, + pretrained_vectors=pretrained_vectors, + conv_depth=depth, + bilstm_depth=0, + maxout_pieces=maxout_pieces, + window_size=window_size, + subword_features=subword_features, + char_embed=False, + nM=0, + nC=0, + ) + + +@registry.architectures.register("spacy.HashCharEmbedCNN.v1") +def hash_charembed_cnn( + pretrained_vectors, + width, + depth, + embed_size, + maxout_pieces, + window_size, + subword_features, + nM=0, + nC=0, +): + # Allows using character embeddings by setting nC, nM and char_embed=True + return build_Tok2Vec_model( + width=width, + embed_size=embed_size, + pretrained_vectors=pretrained_vectors, + conv_depth=depth, + bilstm_depth=0, + maxout_pieces=maxout_pieces, + window_size=window_size, + subword_features=subword_features, + char_embed=True, + nM=nM, + nC=nC, + ) + + +@registry.architectures.register("spacy.HashEmbedBiLSTM.v1") +def hash_embed_bilstm_v1( + pretrained_vectors, width, depth, embed_size, subword_features +): + # Does not use character embeddings: set to False by default + return build_Tok2Vec_model( + width=width, + embed_size=embed_size, + pretrained_vectors=pretrained_vectors, + bilstm_depth=depth, + conv_depth=0, + maxout_pieces=0, + window_size=1, + subword_features=subword_features, + char_embed=False, + nM=0, + nC=0, + ) + + +@registry.architectures.register("spacy.HashCharEmbedBiLSTM.v1") +def hash_embed_bilstm_v1( + pretrained_vectors, width, depth, embed_size, subword_features, nM=0, nC=0 +): + # Allows using character embeddings by setting nC, nM and char_embed=True + return build_Tok2Vec_model( + width=width, + embed_size=embed_size, + pretrained_vectors=pretrained_vectors, + bilstm_depth=depth, + conv_depth=0, + maxout_pieces=0, + window_size=1, + subword_features=subword_features, + char_embed=True, + nM=nM, + nC=nC, + ) + + +@registry.architectures.register("spacy.MultiHashEmbed.v1") +def MultiHashEmbed(config): + # For backwards compatibility with models before the architecture registry, + # we have to be careful to get exactly the same model structure. One subtle + # trick is that when we define concatenation with the operator, the operator + # is actually binary associative. So when we write (a | b | c), we're actually + # getting concatenate(concatenate(a, b), c). That's why the implementation + # is a bit ugly here. + cols = config["columns"] + width = config["width"] + rows = config["rows"] + + norm = HashEmbed(width, rows, column=cols.index("NORM")) + if config["use_subwords"]: + prefix = HashEmbed(width, rows // 2, column=cols.index("PREFIX")) + suffix = HashEmbed(width, rows // 2, column=cols.index("SUFFIX")) + shape = HashEmbed(width, rows // 2, column=cols.index("SHAPE")) + if config.get("@pretrained_vectors"): + glove = make_layer(config["@pretrained_vectors"]) + mix = make_layer(config["@mix"]) + + with Model.define_operators({">>": chain, "|": concatenate}): + if config["use_subwords"] and config["@pretrained_vectors"]: + mix._layers[0].set_dim("nI", width * 5) + layer = uniqued( + (glove | norm | prefix | suffix | shape) >> mix, + column=cols.index("ORTH"), + ) + elif config["use_subwords"]: + mix._layers[0].set_dim("nI", width * 4) + layer = uniqued( + (norm | prefix | suffix | shape) >> mix, column=cols.index("ORTH") + ) + elif config["@pretrained_vectors"]: + mix._layers[0].set_dim("nI", width * 2) + layer = uniqued((glove | norm) >> mix, column=cols.index("ORTH")) + else: + layer = norm + layer.attrs["cfg"] = config + return layer + + +@registry.architectures.register("spacy.CharacterEmbed.v1") +def CharacterEmbed(config): + width = config["width"] + chars = config["chars"] + + chr_embed = _character_embed.CharacterEmbed(nM=width, nC=chars) + other_tables = make_layer(config["@embed_features"]) + mix = make_layer(config["@mix"]) + + model = chain(concatenate(chr_embed, other_tables), mix) + model.attrs["cfg"] = config + return model + + +@registry.architectures.register("spacy.MaxoutWindowEncoder.v1") +def MaxoutWindowEncoder(config): + nO = config["width"] + nW = config["window_size"] + nP = config["pieces"] + depth = config["depth"] + + cnn = ( + expand_window(window_size=nW), + Maxout(nO=nO, nI=nO * ((nW * 2) + 1), nP=nP, dropout=0.0, normalize=True), + ) + model = clone(residual(cnn), depth) + model.set_dim("nO", nO) + model.attrs["receptive_field"] = nW * depth + return model + + +@registry.architectures.register("spacy.MishWindowEncoder.v1") +def MishWindowEncoder(config): + nO = config["width"] + nW = config["window_size"] + depth = config["depth"] + + cnn = chain( + expand_window(window_size=nW), + Mish(nO=nO, nI=nO * ((nW * 2) + 1)), + LayerNorm(nO), + ) + model = clone(residual(cnn), depth) + model.set_dim("nO", nO) + return model + + +@registry.architectures.register("spacy.TorchBiLSTMEncoder.v1") +def TorchBiLSTMEncoder(config): + import torch.nn + + # TODO FIX + from thinc.api import PyTorchRNNWrapper + + width = config["width"] + depth = config["depth"] + if depth == 0: + return noop() + return with_padded( + PyTorchRNNWrapper(torch.nn.LSTM(width, width // 2, depth, bidirectional=True)) + ) + + +# TODO: update +_EXAMPLE_CONFIG = { + "@doc2feats": { + "arch": "Doc2Feats", + "config": {"columns": ["ID", "NORM", "PREFIX", "SUFFIX", "SHAPE", "ORTH"]}, + }, + "@embed": { + "arch": "spacy.MultiHashEmbed.v1", + "config": { + "width": 96, + "rows": 2000, + "columns": ["ID", "NORM", "PREFIX", "SUFFIX", "SHAPE", "ORTH"], + "use_subwords": True, + "@pretrained_vectors": { + "arch": "TransformedStaticVectors", + "config": { + "vectors_name": "en_vectors_web_lg.vectors", + "width": 96, + "column": 0, + }, + }, + "@mix": { + "arch": "LayerNormalizedMaxout", + "config": {"width": 96, "pieces": 3}, + }, + }, + }, + "@encode": { + "arch": "MaxoutWindowEncode", + "config": {"width": 96, "window_size": 1, "depth": 4, "pieces": 3}, + }, +} + + +def build_Tok2Vec_model( + width, + embed_size, + pretrained_vectors, + window_size, + maxout_pieces, + subword_features, + char_embed, + nM, + nC, + conv_depth, + bilstm_depth, +) -> Model: + if char_embed: + subword_features = False + cols = [ID, NORM, PREFIX, SUFFIX, SHAPE, ORTH] + with Model.define_operators({">>": chain, "|": concatenate, "**": clone}): + norm = HashEmbed(nO=width, nV=embed_size, column=cols.index(NORM)) + if subword_features: + prefix = HashEmbed(nO=width, nV=embed_size // 2, column=cols.index(PREFIX)) + suffix = HashEmbed(nO=width, nV=embed_size // 2, column=cols.index(SUFFIX)) + shape = HashEmbed(nO=width, nV=embed_size // 2, column=cols.index(SHAPE)) + else: + prefix, suffix, shape = (None, None, None) + if pretrained_vectors is not None: + glove = StaticVectors( + vectors=pretrained_vectors.data, + nO=width, + column=cols.index(ID), + dropout=0.0, + ) + + if subword_features: + columns = 5 + embed = uniqued( + (glove | norm | prefix | suffix | shape) + >> Maxout( + nO=width, + nI=width * columns, + nP=maxout_pieces, + dropout=0.0, + normalize=True, + ), + column=cols.index(ORTH), + ) + else: + columns = 2 + embed = uniqued( + (glove | norm) + >> Maxout( + nO=width, + nI=width * columns, + nP=maxout_pieces, + dropout=0.0, + normalize=True, + ), + column=cols.index(ORTH), + ) + elif subword_features: + columns = 4 + embed = uniqued( + concatenate(norm, prefix, suffix, shape) + >> Maxout( + nO=width, + nI=width * columns, + nP=maxout_pieces, + dropout=0.0, + normalize=True, + ), + column=cols.index(ORTH), + ) + elif char_embed: + embed = _character_embed.CharacterEmbed(nM=nM, nC=nC) | FeatureExtractor( + cols + ) >> with_array(norm) + reduce_dimensions = Maxout( + nO=width, + nI=nM * nC + width, + nP=maxout_pieces, + dropout=0.0, + normalize=True, + ) + else: + embed = norm + + convolution = residual( + expand_window(window_size=window_size) + >> Maxout( + nO=width, + nI=width * ((window_size * 2) + 1), + nP=maxout_pieces, + dropout=0.0, + normalize=True, + ) + ) + if char_embed: + tok2vec = embed >> with_array( + reduce_dimensions >> convolution ** conv_depth, pad=conv_depth + ) + else: + tok2vec = FeatureExtractor(cols) >> with_array( + embed >> convolution ** conv_depth, pad=conv_depth + ) + + if bilstm_depth >= 1: + tok2vec = tok2vec >> PyTorchLSTM( + nO=width, nI=width, depth=bilstm_depth, bi=True + ) + tok2vec.set_dim("nO", width) + tok2vec.set_ref("embed", embed) + return tok2vec diff --git a/spacy/ml/tok2vec.py b/spacy/ml/tok2vec.py index 5e51bc47a..e69de29bb 100644 --- a/spacy/ml/tok2vec.py +++ b/spacy/ml/tok2vec.py @@ -1,178 +0,0 @@ -from thinc.api import Model, chain, clone, concatenate, with_array, uniqued, noop -from thinc.api import with_padded, Maxout, expand_window, HashEmbed, StaticVectors -from thinc.api import residual, LayerNorm, FeatureExtractor - -from ..ml import _character_embed -from ..util import make_layer, registry - - -@registry.architectures.register("spacy.Tok2Vec.v1") -def Tok2Vec(config): - doc2feats = make_layer(config["@doc2feats"]) - embed = make_layer(config["@embed"]) - encode = make_layer(config["@encode"]) - field_size = 0 - if encode.has_attr("receptive_field"): - field_size = encode.attrs["receptive_field"] - tok2vec = chain(doc2feats, with_array(chain(embed, encode), pad=field_size)) - tok2vec.attrs["cfg"] = config - tok2vec.set_dim("nO", encode.get_dim("nO")) - tok2vec.set_ref("embed", embed) - tok2vec.set_ref("encode", encode) - return tok2vec - - -@registry.architectures.register("spacy.Doc2Feats.v1") -def Doc2Feats(config): - columns = config["columns"] - return FeatureExtractor(columns) - - -@registry.architectures.register("spacy.MultiHashEmbed.v1") -def MultiHashEmbed(config): - # For backwards compatibility with models before the architecture registry, - # we have to be careful to get exactly the same model structure. One subtle - # trick is that when we define concatenation with the operator, the operator - # is actually binary associative. So when we write (a | b | c), we're actually - # getting concatenate(concatenate(a, b), c). That's why the implementation - # is a bit ugly here. - cols = config["columns"] - width = config["width"] - rows = config["rows"] - - norm = HashEmbed(width, rows, column=cols.index("NORM"), dropout=0.0) - if config["use_subwords"]: - prefix = HashEmbed(width, rows // 2, column=cols.index("PREFIX"), dropout=0.0) - suffix = HashEmbed(width, rows // 2, column=cols.index("SUFFIX"), dropout=0.0) - shape = HashEmbed(width, rows // 2, column=cols.index("SHAPE"), dropout=0.0) - if config.get("@pretrained_vectors"): - glove = make_layer(config["@pretrained_vectors"]) - mix = make_layer(config["@mix"]) - - with Model.define_operators({">>": chain, "|": concatenate}): - if config["use_subwords"] and config["@pretrained_vectors"]: - mix._layers[0].set_dim("nI", width * 5) - layer = uniqued( - (glove | norm | prefix | suffix | shape) >> mix, - column=cols.index("ORTH"), - ) - elif config["use_subwords"]: - mix._layers[0].set_dim("nI", width * 4) - layer = uniqued( - (norm | prefix | suffix | shape) >> mix, column=cols.index("ORTH") - ) - elif config["@pretrained_vectors"]: - mix._layers[0].set_dim("nI", width * 2) - layer = uniqued((glove | norm) >> mix, column=cols.index("ORTH"),) - else: - layer = norm - layer.attrs["cfg"] = config - return layer - - -@registry.architectures.register("spacy.CharacterEmbed.v1") -def CharacterEmbed(config): - width = config["width"] - chars = config["chars"] - - chr_embed = _character_embed.CharacterEmbed(nM=width, nC=chars) - other_tables = make_layer(config["@embed_features"]) - mix = make_layer(config["@mix"]) - - model = chain(concatenate(chr_embed, other_tables), mix) - model.attrs["cfg"] = config - return model - - -@registry.architectures.register("spacy.MaxoutWindowEncoder.v1") -def MaxoutWindowEncoder(config): - nO = config["width"] - nW = config["window_size"] - nP = config["pieces"] - depth = config["depth"] - cnn = ( - expand_window(window_size=nW), - Maxout(nO=nO, nI=nO * ((nW * 2) + 1), nP=nP, dropout=0.0, normalize=True), - ) - model = clone(residual(cnn), depth) - model.set_dim("nO", nO) - model.attrs["receptive_field"] = nW * depth - return model - - -@registry.architectures.register("spacy.MishWindowEncoder.v1") -def MishWindowEncoder(config): - from thinc.api import Mish - - nO = config["width"] - nW = config["window_size"] - depth = config["depth"] - cnn = chain( - expand_window(window_size=nW), - Mish(nO=nO, nI=nO * ((nW * 2) + 1)), - LayerNorm(nO), - ) - model = clone(residual(cnn), depth) - model.set_dim("nO", nO) - return model - - -@registry.architectures.register("spacy.PretrainedVectors.v1") -def PretrainedVectors(config): - # TODO: actual vectors instead of name - return StaticVectors( - vectors=config["vectors_name"], - nO=config["width"], - column=config["column"], - dropout=0.0, - ) - - -@registry.architectures.register("spacy.TorchBiLSTMEncoder.v1") -def TorchBiLSTMEncoder(config): - import torch.nn - - # TODO: FIX - from thinc.api import PyTorchRNNWrapper - - width = config["width"] - depth = config["depth"] - if depth == 0: - return noop() - return with_padded( - PyTorchRNNWrapper(torch.nn.LSTM(width, width // 2, depth, bidirectional=True)) - ) - - -# TODO: update -_EXAMPLE_CONFIG = { - "@doc2feats": { - "arch": "Doc2Feats", - "config": {"columns": ["ID", "NORM", "PREFIX", "SUFFIX", "SHAPE", "ORTH"]}, - }, - "@embed": { - "arch": "spacy.MultiHashEmbed.v1", - "config": { - "width": 96, - "rows": 2000, - "columns": ["ID", "NORM", "PREFIX", "SUFFIX", "SHAPE", "ORTH"], - "use_subwords": True, - "@pretrained_vectors": { - "arch": "TransformedStaticVectors", - "config": { - "vectors_name": "en_vectors_web_lg.vectors", - "width": 96, - "column": 0, - }, - }, - "@mix": { - "arch": "LayerNormalizedMaxout", - "config": {"width": 96, "pieces": 3}, - }, - }, - }, - "@encode": { - "arch": "MaxoutWindowEncode", - "config": {"width": 96, "window_size": 1, "depth": 4, "pieces": 3}, - }, -} diff --git a/spacy/pipeline/entityruler.py b/spacy/pipeline/entityruler.py index e211acb44..06c568ac9 100644 --- a/spacy/pipeline/entityruler.py +++ b/spacy/pipeline/entityruler.py @@ -66,7 +66,7 @@ class EntityRuler(object): self.add_patterns(patterns) @classmethod - def from_nlp(cls, nlp, **cfg): + def from_nlp(cls, nlp, model=None, **cfg): return cls(nlp, **cfg) def __len__(self): diff --git a/spacy/pipeline/hooks.py b/spacy/pipeline/hooks.py index d48b04bd1..351323ae9 100644 --- a/spacy/pipeline/hooks.py +++ b/spacy/pipeline/hooks.py @@ -76,11 +76,9 @@ class SimilarityHook(Pipe): yield self(doc) def predict(self, doc1, doc2): - self.require_model() return self.model.predict([(doc1, doc2)]) def update(self, doc1_doc2, golds, sgd=None, drop=0.0): - self.require_model() sims, bp_sims = self.model.begin_update(doc1_doc2) def begin_training(self, _=tuple(), pipeline=None, sgd=None, **kwargs): diff --git a/spacy/pipeline/morphologizer.pyx b/spacy/pipeline/morphologizer.pyx index 999132b35..b6a6045d1 100644 --- a/spacy/pipeline/morphologizer.pyx +++ b/spacy/pipeline/morphologizer.pyx @@ -15,25 +15,15 @@ from ..tokens.doc cimport Doc from ..vocab cimport Vocab from ..morphology cimport Morphology -from ..ml.component_models import build_morphologizer_model - @component("morphologizer", assigns=["token.morph", "token.pos"]) class Morphologizer(Pipe): - @classmethod - def Model(cls, **cfg): - if cfg.get('pretrained_dims') and not cfg.get('pretrained_vectors'): - raise ValueError(TempErrors.T008) - class_map = Morphology.create_class_map() - return build_morphologizer_model(class_map.field_sizes, **cfg) - - def __init__(self, vocab, model=True, **cfg): + def __init__(self, vocab, model, **cfg): self.vocab = vocab self.model = model self.cfg = dict(sorted(cfg.items())) - self.cfg.setdefault('cnn_maxout_pieces', 2) - self._class_map = self.vocab.morphology.create_class_map() + self._class_map = self.vocab.morphology.create_class_map() # Morphology.create_class_map() ? @property def labels(self): @@ -58,6 +48,14 @@ class Morphologizer(Pipe): self.set_annotations(docs, features, tensors=tokvecs) yield from docs + def begin_training(self, get_examples=lambda: [], pipeline=None, sgd=None, + **kwargs): + self.set_output(len(self.labels)) + self.model.initialize() + if sgd is None: + sgd = self.create_optimizer() + return sgd + def predict(self, docs): if not any(len(doc) for doc in docs): # Handle case where there are no tokens in any docs. @@ -65,8 +63,8 @@ class Morphologizer(Pipe): guesses = [self.model.ops.alloc((0, n_labels)) for doc in docs] tokvecs = self.model.ops.alloc((0, self.model.get_ref("tok2vec").get_dim("nO"))) return guesses, tokvecs - tokvecs = self.model.tok2vec(docs) - scores = self.model.softmax(tokvecs) + tokvecs = self.model.get_ref("tok2vec")(docs) + scores = self.model.get_ref("softmax")(tokvecs) return scores, tokvecs def set_annotations(self, docs, batch_scores, tensors=None): diff --git a/spacy/pipeline/pipes.pyx b/spacy/pipeline/pipes.pyx index ad75d2e78..b9bf1ccd6 100644 --- a/spacy/pipeline/pipes.pyx +++ b/spacy/pipeline/pipes.pyx @@ -3,8 +3,7 @@ import numpy import srsly import random -from thinc.api import chain, Linear, Maxout, Softmax, LayerNorm, list2array -from thinc.api import zero_init, CosineDistance, to_categorical, get_array_module +from thinc.api import CosineDistance, to_categorical, get_array_module from thinc.api import set_dropout_rate from ..tokens.doc cimport Doc @@ -22,11 +21,6 @@ from ..attrs import POS, ID from ..util import link_vectors_to_models, create_default_optimizer from ..parts_of_speech import X from ..kb import KnowledgeBase -from ..ml.component_models import Tok2Vec, build_tagger_model -from ..ml.component_models import build_text_classifier -from ..ml.component_models import build_simple_cnn_text_classifier -from ..ml.component_models import build_bow_text_classifier, build_nel_encoder -from ..ml.component_models import masked_language_model from ..errors import Errors, TempErrors, user_warning, Warnings from .. import util @@ -47,13 +41,8 @@ class Pipe(object): name = None @classmethod - def Model(cls, *shape, **kwargs): - """Initialize a model for the pipe.""" - raise NotImplementedError - - @classmethod - def from_nlp(cls, nlp, **cfg): - return cls(nlp.vocab, **cfg) + def from_nlp(cls, nlp, model, **cfg): + return cls(nlp.vocab, model, **cfg) def _get_doc(self, example): """ Use this method if the `example` can be both a Doc or an Example """ @@ -61,7 +50,7 @@ class Pipe(object): return example return example.doc - def __init__(self, vocab, model=True, **cfg): + def __init__(self, vocab, model, **cfg): """Create a new pipe instance.""" raise NotImplementedError @@ -72,7 +61,6 @@ class Pipe(object): Both __call__ and pipe should delegate to the `predict()` and `set_annotations()` methods. """ - self.require_model() doc = self._get_doc(example) predictions = self.predict([doc]) if isinstance(predictions, tuple) and len(predictions) == 2: @@ -85,11 +73,6 @@ class Pipe(object): return example return doc - def require_model(self): - """Raise an error if the component's model is not initialized.""" - if getattr(self, "model", None) in (None, True, False): - raise ValueError(Errors.E109.format(name=self.name)) - def pipe(self, stream, batch_size=128, n_threads=-1, as_example=False): """Apply the pipe to a stream of documents. @@ -116,7 +99,6 @@ class Pipe(object): """Apply the pipeline's model to a batch of docs, without modifying them. """ - self.require_model() raise NotImplementedError def set_annotations(self, docs, scores, tensors=None): @@ -158,22 +140,23 @@ class Pipe(object): ): """Initialize the pipe for training, using data exampes if available. If no model has been initialized yet, the model is added.""" - if self.model is True: - self.model = self.Model(**self.cfg) + self.model.initialize() if hasattr(self, "vocab"): link_vectors_to_models(self.vocab) - self.model.initialize() if sgd is None: sgd = self.create_optimizer() return sgd + def set_output(self, nO): + self.model.set_dim("nO", nO) + if self.model.has_ref("output_layer"): + self.model.get_ref("output_layer").set_dim("nO", nO) + def get_gradients(self): """Get non-zero gradients of the model's parameters, as a dictionary keyed by the parameter ID. The values are (weights, gradients) tuples. """ gradients = {} - if self.model in (None, True, False): - return gradients queue = [self.model] seen = set() for node in queue: @@ -199,8 +182,7 @@ class Pipe(object): """ serialize = {} serialize["cfg"] = lambda: srsly.json_dumps(self.cfg) - if self.model not in (True, False, None): - serialize["model"] = self.model.to_bytes + serialize["model"] = self.model.to_bytes if hasattr(self, "vocab"): serialize["vocab"] = self.vocab.to_bytes exclude = util.get_serialization_exclude(serialize, exclude, kwargs) @@ -210,20 +192,15 @@ class Pipe(object): """Load the pipe from a bytestring.""" def load_model(b): - # TODO: Remove this once we don't have to handle previous models - if self.cfg.get("pretrained_dims") and "pretrained_vectors" not in self.cfg: - self.cfg["pretrained_vectors"] = self.vocab.vectors - if self.model is True: - self.model = self.Model(**self.cfg) try: self.model.from_bytes(b) except AttributeError: raise ValueError(Errors.E149) deserialize = {} - deserialize["cfg"] = lambda b: self.cfg.update(srsly.json_loads(b)) if hasattr(self, "vocab"): deserialize["vocab"] = lambda b: self.vocab.from_bytes(b) + deserialize["cfg"] = lambda b: self.cfg.update(srsly.json_loads(b)) deserialize["model"] = load_model exclude = util.get_serialization_exclude(deserialize, exclude, kwargs) util.from_bytes(bytes_data, deserialize, exclude) @@ -234,8 +211,7 @@ class Pipe(object): serialize = {} serialize["cfg"] = lambda p: srsly.write_json(p, self.cfg) serialize["vocab"] = lambda p: self.vocab.to_disk(p) - if self.model not in (None, True, False): - serialize["model"] = lambda p: p.open("wb").write(self.model.to_bytes()) + serialize["model"] = lambda p: p.open("wb").write(self.model.to_bytes()) exclude = util.get_serialization_exclude(serialize, exclude, kwargs) util.to_disk(path, serialize, exclude) @@ -243,19 +219,14 @@ class Pipe(object): """Load the pipe from disk.""" def load_model(p): - # TODO: Remove this once we don't have to handle previous models - if self.cfg.get("pretrained_dims") and "pretrained_vectors" not in self.cfg: - self.cfg["pretrained_vectors"] = self.vocab.vectors - if self.model is True: - self.model = self.Model(**self.cfg) try: self.model.from_bytes(p.open("rb").read()) except AttributeError: raise ValueError(Errors.E149) deserialize = {} - deserialize["cfg"] = lambda p: self.cfg.update(_load_cfg(p)) deserialize["vocab"] = lambda p: self.vocab.from_disk(p) + deserialize["cfg"] = lambda p: self.cfg.update(_load_cfg(p)) deserialize["model"] = load_model exclude = util.get_serialization_exclude(deserialize, exclude, kwargs) util.from_disk(path, deserialize, exclude) @@ -266,31 +237,13 @@ class Pipe(object): class Tensorizer(Pipe): """Pre-train position-sensitive vectors for tokens.""" - @classmethod - def Model(cls, output_size=300, **cfg): - """Create a new statistical model for the class. - - width (int): Output size of the model. - embed_size (int): Number of vectors in the embedding table. - **cfg: Config parameters. - RETURNS (Model): A `thinc.model.Model` or similar instance. - """ - input_size = util.env_opt("token_vector_width", cfg.get("input_size", 96)) - return Linear(output_size, input_size, init_W=zero_init) - - def __init__(self, vocab, model=True, **cfg): + def __init__(self, vocab, model, **cfg): """Construct a new statistical model. Weights are not allocated on initialisation. vocab (Vocab): A `Vocab` instance. The model must share the same `Vocab` instance with the `Doc` objects it will process. - model (Model): A `Model` instance or `True` to allocate one later. **cfg: Config parameters. - - EXAMPLE: - >>> from spacy.pipeline import TokenVectorEncoder - >>> tok2vec = TokenVectorEncoder(nlp.vocab) - >>> tok2vec.model = tok2vec.Model(128, 5000) """ self.vocab = vocab self.model = model @@ -337,7 +290,6 @@ class Tensorizer(Pipe): docs (iterable): A sequence of `Doc` objects. RETURNS (object): Vector representations for each token in the docs. """ - self.require_model() inputs = self.model.ops.flatten([doc.tensor for doc in docs]) outputs = self.model(inputs) return self.model.ops.unflatten(outputs, [len(d) for d in docs]) @@ -362,7 +314,6 @@ class Tensorizer(Pipe): sgd (callable): An optimizer. RETURNS (dict): Results from the update. """ - self.require_model() examples = Example.to_example_objects(examples) inputs = [] bp_inputs = [] @@ -405,10 +356,8 @@ class Tensorizer(Pipe): """ if pipeline is not None: for name, model in pipeline: - if getattr(model, "tok2vec", None): - self.input_models.append(model.tok2vec) - if self.model is True: - self.model = self.Model(**self.cfg) + if model.has_ref("tok2vec"): + self.input_models.append(model.get_ref("tok2vec")) self.model.initialize() link_vectors_to_models(self.vocab) if sgd is None: @@ -423,7 +372,7 @@ class Tagger(Pipe): DOCS: https://spacy.io/api/tagger """ - def __init__(self, vocab, model=True, **cfg): + def __init__(self, vocab, model, **cfg): self.vocab = vocab self.model = model self._rehearsal_model = None @@ -433,13 +382,6 @@ class Tagger(Pipe): def labels(self): return tuple(self.vocab.morphology.tag_names) - @property - def tok2vec(self): - if self.model in (None, True, False): - return None - else: - return chain(self.model.get_ref("tok2vec"), list2array()) - def __call__(self, example): doc = self._get_doc(example) tags = self.predict([doc]) @@ -465,7 +407,6 @@ class Tagger(Pipe): yield from docs def predict(self, docs): - self.require_model() if not any(len(doc) for doc in docs): # Handle cases where there are no tokens in any docs. n_labels = len(self.labels) @@ -513,7 +454,6 @@ class Tagger(Pipe): doc.is_tagged = True def update(self, examples, drop=0., sgd=None, losses=None, set_annotations=False): - self.require_model() examples = Example.to_example_objects(examples) if losses is not None and self.name not in losses: losses[self.name] = 0. @@ -600,52 +540,21 @@ class Tagger(Pipe): vocab.morphology = Morphology(vocab.strings, new_tag_map, vocab.morphology.lemmatizer, exc=vocab.morphology.exc) - self.cfg["pretrained_vectors"] = kwargs.get("pretrained_vectors") - if self.model is True: - for hp in ["token_vector_width", "conv_depth"]: - if hp in kwargs: - self.cfg[hp] = kwargs[hp] - self.model = self.Model(self.vocab.morphology.n_tags, **self.cfg) + self.set_output(len(self.labels)) + self.model.initialize() # Get batch of example docs, example outputs to call begin_training(). # This lets the model infer shapes. - n_tags = self.vocab.morphology.n_tags - for node in self.model.walk(): - # TODO: softmax hack ? - if node.name == "softmax" and node.has_dim("nO") is None: - node.set_dim("nO", n_tags) link_vectors_to_models(self.vocab) - self.model.initialize() if sgd is None: sgd = self.create_optimizer() return sgd - @classmethod - def Model(cls, n_tags=None, **cfg): - if cfg.get("pretrained_dims") and not cfg.get("pretrained_vectors"): - raise ValueError(TempErrors.T008) - if "tok2vec" in cfg: - tok2vec = cfg["tok2vec"] - else: - config = { - "width": cfg.get("token_vector_width", 96), - "embed_size": cfg.get("embed_size", 2000), - "pretrained_vectors": cfg.get("pretrained_vectors", None), - "window_size": cfg.get("window_size", 1), - "cnn_maxout_pieces": cfg.get("cnn_maxout_pieces", 3), - "subword_features": cfg.get("subword_features", True), - "char_embed": cfg.get("char_embed", False), - "conv_depth": cfg.get("conv_depth", 4), - "bilstm_depth": cfg.get("bilstm_depth", 0), - } - tok2vec = Tok2Vec(**config) - return build_tagger_model(n_tags, tok2vec) - def add_label(self, label, values=None): if not isinstance(label, str): raise ValueError(Errors.E187) if label in self.labels: return 0 - if self.model not in (True, False, None): + if self.model.has_dim("nO"): # Here's how the model resizing will work, once the # neuron-to-tag mapping is no longer controlled by # the Morphology class, which sorts the tag names. @@ -672,8 +581,7 @@ class Tagger(Pipe): def to_bytes(self, exclude=tuple(), **kwargs): serialize = {} - if self.model not in (None, True, False): - serialize["model"] = self.model.to_bytes + serialize["model"] = self.model.to_bytes serialize["vocab"] = self.vocab.to_bytes serialize["cfg"] = lambda: srsly.json_dumps(self.cfg) tag_map = dict(sorted(self.vocab.morphology.tag_map.items())) @@ -683,14 +591,6 @@ class Tagger(Pipe): def from_bytes(self, bytes_data, exclude=tuple(), **kwargs): def load_model(b): - # TODO: Remove this once we don't have to handle previous models - if self.cfg.get("pretrained_dims") and "pretrained_vectors" not in self.cfg: - self.cfg["pretrained_vectors"] = self.vocab.vectors - if self.model is True: - token_vector_width = util.env_opt( - "token_vector_width", - self.cfg.get("token_vector_width", 96)) - self.model = self.Model(**self.cfg) try: self.model.from_bytes(b) except AttributeError: @@ -719,18 +619,13 @@ class Tagger(Pipe): "vocab": lambda p: self.vocab.to_disk(p), "tag_map": lambda p: srsly.write_msgpack(p, tag_map), "model": lambda p: p.open("wb").write(self.model.to_bytes()), - "cfg": lambda p: srsly.write_json(p, self.cfg) + "cfg": lambda p: srsly.write_json(p, self.cfg), } exclude = util.get_serialization_exclude(serialize, exclude, kwargs) util.to_disk(path, serialize, exclude) def from_disk(self, path, exclude=tuple(), **kwargs): def load_model(p): - # TODO: Remove this once we don't have to handle previous models - if self.cfg.get("pretrained_dims") and "pretrained_vectors" not in self.cfg: - self.cfg["pretrained_vectors"] = self.vocab.vectors - if self.model is True: - self.model = self.Model(**self.cfg) with p.open("rb") as file_: try: self.model.from_bytes(file_.read()) @@ -745,8 +640,8 @@ class Tagger(Pipe): exc=self.vocab.morphology.exc) deserialize = { - "cfg": lambda p: self.cfg.update(_load_cfg(p)), "vocab": lambda p: self.vocab.from_disk(p), + "cfg": lambda p: self.cfg.update(_load_cfg(p)), "tag_map": load_tag_map, "model": load_model, } @@ -762,16 +657,11 @@ class SentenceRecognizer(Tagger): DOCS: https://spacy.io/api/sentencerecognizer """ - def __init__(self, vocab, model=True, **cfg): + def __init__(self, vocab, model, **cfg): self.vocab = vocab self.model = model self._rehearsal_model = None self.cfg = dict(sorted(cfg.items())) - self.cfg.setdefault("cnn_maxout_pieces", 2) - self.cfg.setdefault("subword_features", True) - self.cfg.setdefault("token_vector_width", 12) - self.cfg.setdefault("conv_depth", 1) - self.cfg.setdefault("pretrained_vectors", None) @property def labels(self): @@ -797,7 +687,6 @@ class SentenceRecognizer(Tagger): doc.c[j].sent_start = -1 def update(self, examples, drop=0., sgd=None, losses=None): - self.require_model() examples = Example.to_example_objects(examples) if losses is not None and self.name not in losses: losses[self.name] = 0. @@ -844,20 +733,12 @@ class SentenceRecognizer(Tagger): def begin_training(self, get_examples=lambda: [], pipeline=None, sgd=None, **kwargs): cdef Vocab vocab = self.vocab - if self.model is True: - for hp in ["token_vector_width", "conv_depth"]: - if hp in kwargs: - self.cfg[hp] = kwargs[hp] - self.model = self.Model(len(self.labels), **self.cfg) + self.set_output(len(self.labels)) + self.model.initialize() if sgd is None: sgd = self.create_optimizer() - self.model.initialize() return sgd - @classmethod - def Model(cls, n_tags, **cfg): - return build_tagger_model(n_tags, **cfg) - def add_label(self, label, values=None): raise NotImplementedError @@ -867,8 +748,7 @@ class SentenceRecognizer(Tagger): def to_bytes(self, exclude=tuple(), **kwargs): serialize = {} - if self.model not in (None, True, False): - serialize["model"] = self.model.to_bytes + serialize["model"] = self.model.to_bytes serialize["vocab"] = self.vocab.to_bytes serialize["cfg"] = lambda: srsly.json_dumps(self.cfg) exclude = util.get_serialization_exclude(serialize, exclude, kwargs) @@ -876,8 +756,6 @@ class SentenceRecognizer(Tagger): def from_bytes(self, bytes_data, exclude=tuple(), **kwargs): def load_model(b): - if self.model is True: - self.model = self.Model(len(self.labels), **self.cfg) try: self.model.from_bytes(b) except AttributeError: @@ -896,15 +774,13 @@ class SentenceRecognizer(Tagger): serialize = { "vocab": lambda p: self.vocab.to_disk(p), "model": lambda p: p.open("wb").write(self.model.to_bytes()), - "cfg": lambda p: srsly.write_json(p, self.cfg) + "cfg": lambda p: srsly.write_json(p, self.cfg), } exclude = util.get_serialization_exclude(serialize, exclude, kwargs) util.to_disk(path, serialize, exclude) def from_disk(self, path, exclude=tuple(), **kwargs): def load_model(p): - if self.model is True: - self.model = self.Model(len(self.labels), **self.cfg) with p.open("rb") as file_: try: self.model.from_bytes(file_.read()) @@ -912,8 +788,8 @@ class SentenceRecognizer(Tagger): raise ValueError(Errors.E149) deserialize = { - "cfg": lambda p: self.cfg.update(_load_cfg(p)), "vocab": lambda p: self.vocab.from_disk(p), + "cfg": lambda p: self.cfg.update(_load_cfg(p)), "model": load_model, } exclude = util.get_serialization_exclude(deserialize, exclude, kwargs) @@ -927,7 +803,7 @@ class MultitaskObjective(Tagger): side-objective. """ - def __init__(self, vocab, model=True, target='dep_tag_offset', **cfg): + def __init__(self, vocab, model, target='dep_tag_offset', **cfg): self.vocab = vocab self.model = model if target == "dep": @@ -947,7 +823,8 @@ class MultitaskObjective(Tagger): else: raise ValueError(Errors.E016) self.cfg = dict(cfg) - self.cfg.setdefault("cnn_maxout_pieces", 2) + # TODO: remove - put in config + self.cfg.setdefault("maxout_pieces", 2) @property def labels(self): @@ -969,30 +846,15 @@ class MultitaskObjective(Tagger): label = self.make_label(i, example.token_annotation) if label is not None and label not in self.labels: self.labels[label] = len(self.labels) - if self.model is True: - token_vector_width = util.env_opt("token_vector_width") - self.model = self.Model(len(self.labels), tok2vec=tok2vec) - link_vectors_to_models(self.vocab) self.model.initialize() + link_vectors_to_models(self.vocab) if sgd is None: sgd = self.create_optimizer() return sgd - @classmethod - def Model(cls, n_tags, tok2vec=None, **cfg): - token_vector_width = util.env_opt("token_vector_width", 96) - model = chain( - tok2vec, - Maxout(nO=token_vector_width*2, nI=token_vector_width, nP=3, dropout=0.0), - LayerNorm(token_vector_width*2), - Softmax(nO=n_tags, nI=token_vector_width*2) - ) - return model - def predict(self, docs): - self.require_model() - tokvecs = self.model.tok2vec(docs) - scores = self.model.softmax(tokvecs) + tokvecs = self.model.get_ref("tok2vec")(docs) + scores = self.model.get_ref("softmax")(tokvecs) return tokvecs, scores def get_loss(self, examples, scores): @@ -1097,18 +959,7 @@ class MultitaskObjective(Tagger): class ClozeMultitask(Pipe): - @classmethod - def Model(cls, vocab, tok2vec, **cfg): - output_size = vocab.vectors.data.shape[1] - output_layer = chain( - Maxout(nO=output_size, nI=tok2vec.get_dim("nO"), nP=3, normalize=True, dropout=0.0), - Linear(nO=output_size, nI=output_size, init_W=zero_init) - ) - model = chain(tok2vec, output_layer) - model = masked_language_model(vocab, model) - return model - - def __init__(self, vocab, model=True, **cfg): + def __init__(self, vocab, model, **cfg): self.vocab = vocab self.model = model self.cfg = cfg @@ -1120,19 +971,16 @@ class ClozeMultitask(Pipe): def begin_training(self, get_examples=lambda: [], pipeline=None, tok2vec=None, sgd=None, **kwargs): link_vectors_to_models(self.vocab) - if self.model is True: - self.model = self.Model(self.vocab, tok2vec) - X = self.model.ops.alloc((5, self.model.get_ref("tok2vec").get_dim("nO"))) self.model.initialize() + X = self.model.ops.alloc((5, self.model.get_ref("tok2vec").get_dim("nO"))) self.model.output_layer.begin_training(X) if sgd is None: sgd = self.create_optimizer() return sgd def predict(self, docs): - self.require_model() - tokvecs = self.model.tok2vec(docs) - vectors = self.model.output_layer(tokvecs) + tokvecs = self.model.get_ref("tok2vec")(docs) + vectors = self.model.get_ref("output_layer")(tokvecs) return tokvecs, vectors def get_loss(self, examples, vectors, prediction): @@ -1150,7 +998,6 @@ class ClozeMultitask(Pipe): pass def rehearse(self, examples, drop=0., sgd=None, losses=None): - self.require_model() examples = Example.to_example_objects(examples) if losses is not None and self.name not in losses: losses[self.name] = 0. @@ -1171,62 +1018,11 @@ class TextCategorizer(Pipe): DOCS: https://spacy.io/api/textcategorizer """ - - @classmethod - def Model(cls, nr_class=1, exclusive_classes=None, **cfg): - if nr_class == 1: - exclusive_classes = False - if exclusive_classes is None: - raise ValueError( - "TextCategorizer Model must specify 'exclusive_classes'. " - "This setting determines whether the model will output " - "scores that sum to 1 for each example. If only one class " - "is true for each example, you should set exclusive_classes=True. " - "For 'multi_label' classification, set exclusive_classes=False." - ) - if "embed_size" not in cfg: - cfg["embed_size"] = util.env_opt("embed_size", 2000) - if "token_vector_width" not in cfg: - cfg["token_vector_width"] = util.env_opt("token_vector_width", 96) - if cfg.get("architecture") == "bow": - return build_bow_text_classifier(nr_class, exclusive_classes, **cfg) - else: - if "tok2vec" in cfg: - tok2vec = cfg["tok2vec"] - else: - config = { - "width": cfg.get("token_vector_width", 96), - "embed_size": cfg.get("embed_size", 2000), - "pretrained_vectors": cfg.get("pretrained_vectors", None), - "window_size": cfg.get("window_size", 1), - "cnn_maxout_pieces": cfg.get("cnn_maxout_pieces", 3), - "subword_features": cfg.get("subword_features", True), - "char_embed": cfg.get("char_embed", False), - "conv_depth": cfg.get("conv_depth", 4), - "bilstm_depth": cfg.get("bilstm_depth", 0), - } - tok2vec = Tok2Vec(**config) - return build_simple_cnn_text_classifier( - tok2vec, - nr_class, - exclusive_classes, - **cfg - ) - - @property - def tok2vec(self): - if self.model in (None, True, False): - return None - else: - return self.model.tok2vec - - def __init__(self, vocab, model=True, **cfg): + def __init__(self, vocab, model, **cfg): self.vocab = vocab self.model = model self._rehearsal_model = None self.cfg = dict(cfg) - if "exclusive_classes" not in cfg: - self.cfg["exclusive_classes"] = True @property def labels(self): @@ -1255,7 +1051,6 @@ class TextCategorizer(Pipe): yield from docs def predict(self, docs): - self.require_model() tensors = [doc.tensor for doc in docs] if not any(len(doc) for doc in docs): @@ -1274,7 +1069,6 @@ class TextCategorizer(Pipe): doc.cats[label] = float(scores[i, j]) def update(self, examples, state=None, drop=0., set_annotations=False, sgd=None, losses=None): - self.require_model() examples = Example.to_example_objects(examples) if not any(len(ex.doc) if ex.doc else 0 for ex in examples): # Handle cases where there are no tokens in any docs. @@ -1311,7 +1105,7 @@ class TextCategorizer(Pipe): losses.setdefault(self.name, 0.0) losses[self.name] += (gradient**2).sum() - def get_loss(self, examples, scores): + def _examples_to_truth(self, examples): golds = [ex.gold for ex in examples] truths = numpy.zeros((len(golds), len(self.labels)), dtype="f") not_missing = numpy.ones((len(golds), len(self.labels)), dtype="f") @@ -1322,6 +1116,10 @@ class TextCategorizer(Pipe): else: not_missing[i, j] = 0. truths = self.model.ops.asarray(truths) + return truths, not_missing + + def get_loss(self, examples, scores): + truths, not_missing = self._examples_to_truth(examples) not_missing = self.model.ops.asarray(not_missing) d_scores = (scores-truths) / scores.shape[0] d_scores *= not_missing @@ -1333,7 +1131,7 @@ class TextCategorizer(Pipe): raise ValueError(Errors.E187) if label in self.labels: return 0 - if self.model not in (None, True, False): + if self.model.has_dim("nO"): # This functionality was available previously, but was broken. # The problem is that we resize the last layer, but the last layer # is actually just an ensemble. We're not resizing the child layers @@ -1348,19 +1146,18 @@ class TextCategorizer(Pipe): return 1 def begin_training(self, get_examples=lambda: [], pipeline=None, sgd=None, **kwargs): - for example in get_examples(): + # TODO: begin_training is not guaranteed to see all data / labels ? + examples = list(get_examples()) + for example in examples: for cat in example.doc_annotation.cats: self.add_label(cat) - if self.model is True: - self.cfg.update(kwargs) - self.require_labels() - self.model = self.Model(len(self.labels), **self.cfg) - link_vectors_to_models(self.vocab) + self.require_labels() + docs = [Doc(Vocab(), words=["hello"])] + truths, _ = self._examples_to_truth(examples) + self.set_output(len(self.labels)) + self.model.initialize(X=docs, Y=truths) if sgd is None: sgd = self.create_optimizer() - # TODO: use get_examples instead - docs = [Doc(Vocab(), words=["hello"])] - self.model.initialize(X=docs) return sgd @@ -1393,7 +1190,7 @@ cdef class DependencyParser(Parser): def init_multitask_objectives(self, get_examples, pipeline, sgd=None, **cfg): for labeller in self._multitasks: - tok2vec = self.model.tok2vec + tok2vec = self.model.get_ref("tok2vec") labeller.begin_training(get_examples, pipeline=pipeline, tok2vec=tok2vec, sgd=sgd) @@ -1423,7 +1220,6 @@ cdef class EntityRecognizer(Parser): assigns = ["doc.ents", "token.ent_iob", "token.ent_type"] requires = [] TransitionSystem = BiluoPushDown - nr_feature = 6 def add_multitask_objective(self, target): if target == "cloze": @@ -1435,7 +1231,7 @@ cdef class EntityRecognizer(Parser): def init_multitask_objectives(self, get_examples, pipeline, sgd=None, **cfg): for labeller in self._multitasks: - tok2vec = self.model.tok2vec + tok2vec = self.model.get_ref("tok2vec") labeller.begin_training(get_examples, pipeline=pipeline, tok2vec=tok2vec) @@ -1464,18 +1260,9 @@ class EntityLinker(Pipe): """ NIL = "NIL" # string used to refer to a non-existing link - @classmethod - def Model(cls, **cfg): - embed_width = cfg.get("embed_width", 300) - hidden_width = cfg.get("hidden_width", 128) - type_to_int = cfg.get("type_to_int", dict()) - - model = build_nel_encoder(embed_width=embed_width, hidden_width=hidden_width, ner_types=len(type_to_int), **cfg) - return model - - def __init__(self, vocab, **cfg): + def __init__(self, vocab, model, **cfg): self.vocab = vocab - self.model = True + self.model = model self.kb = None self.cfg = dict(cfg) self.distance = CosineDistance(normalize=False) @@ -1483,11 +1270,6 @@ class EntityLinker(Pipe): def set_kb(self, kb): self.kb = kb - def require_model(self): - # Raise an error if the component's model is not initialized. - if getattr(self, "model", None) in (None, True, False): - raise ValueError(Errors.E109.format(name=self.name)) - def require_kb(self): # Raise an error if the knowledge base is not initialized. if getattr(self, "kb", None) in (None, True, False): @@ -1495,16 +1277,14 @@ class EntityLinker(Pipe): def begin_training(self, get_examples=lambda: [], pipeline=None, sgd=None, **kwargs): self.require_kb() - self.cfg["entity_width"] = self.kb.entity_vector_length - if self.model is True: - self.model = self.Model(**self.cfg) + nO = self.kb.entity_vector_length + self.set_output(nO) self.model.initialize() if sgd is None: sgd = self.create_optimizer() return sgd def update(self, examples, state=None, set_annotations=False, drop=0.0, sgd=None, losses=None): - self.require_model() self.require_kb() if losses is not None: losses.setdefault(self.name, 0.0) @@ -1614,7 +1394,6 @@ class EntityLinker(Pipe): def predict(self, docs): """ Return the KB IDs for each entity in each doc, including NIL if there is no prediction """ - self.require_model() self.require_kb() entity_count = 0 @@ -1714,15 +1493,12 @@ class EntityLinker(Pipe): serialize["cfg"] = lambda p: srsly.write_json(p, self.cfg) serialize["vocab"] = lambda p: self.vocab.to_disk(p) serialize["kb"] = lambda p: self.kb.dump(p) - if self.model not in (None, True, False): - serialize["model"] = lambda p: p.open("wb").write(self.model.to_bytes()) + serialize["model"] = lambda p: p.open("wb").write(self.model.to_bytes()) exclude = util.get_serialization_exclude(serialize, exclude, kwargs) util.to_disk(path, serialize, exclude) def from_disk(self, path, exclude=tuple(), **kwargs): def load_model(p): - if self.model is True: - self.model = self.Model(**self.cfg) try: self.model.from_bytes(p.open("rb").read()) except AttributeError: @@ -1734,8 +1510,8 @@ class EntityLinker(Pipe): self.set_kb(kb) deserialize = {} - deserialize["cfg"] = lambda p: self.cfg.update(_load_cfg(p)) deserialize["vocab"] = lambda p: self.vocab.from_disk(p) + deserialize["cfg"] = lambda p: self.cfg.update(_load_cfg(p)) deserialize["kb"] = load_kb deserialize["model"] = load_model exclude = util.get_serialization_exclude(deserialize, exclude, kwargs) @@ -1782,7 +1558,7 @@ class Sentencizer(Pipe): self.punct_chars = set(self.default_punct_chars) @classmethod - def from_nlp(cls, nlp, **cfg): + def from_nlp(cls, nlp, model=None, **cfg): return cls(**cfg) def __call__(self, example): @@ -1915,8 +1691,8 @@ class Sentencizer(Pipe): # Cython classes can't be decorated, so we need to add the factories here -Language.factories["parser"] = lambda nlp, **cfg: DependencyParser.from_nlp(nlp, **cfg) -Language.factories["ner"] = lambda nlp, **cfg: EntityRecognizer.from_nlp(nlp, **cfg) +Language.factories["parser"] = lambda nlp, model, **cfg: DependencyParser.from_nlp(nlp, model, **cfg) +Language.factories["ner"] = lambda nlp, model, **cfg: EntityRecognizer.from_nlp(nlp, model, **cfg) __all__ = ["Tagger", "DependencyParser", "EntityRecognizer", "Tensorizer", "TextCategorizer", "EntityLinker", "Sentencizer", "SentenceRecognizer"] diff --git a/spacy/pipeline/tok2vec.py b/spacy/pipeline/tok2vec.py index 8290468cf..a49f94ca3 100644 --- a/spacy/pipeline/tok2vec.py +++ b/spacy/pipeline/tok2vec.py @@ -5,32 +5,21 @@ from ..gold import Example from ..tokens import Doc from ..vocab import Vocab from ..language import component -from ..util import link_vectors_to_models, minibatch, registry, eg2doc +from ..util import link_vectors_to_models, minibatch, eg2doc @component("tok2vec", assigns=["doc.tensor"]) class Tok2Vec(Pipe): - @classmethod - def from_nlp(cls, nlp, **cfg): - return cls(nlp.vocab, **cfg) @classmethod - def Model(cls, architecture, **cfg): - """Create a new statistical model for the class. + def from_nlp(cls, nlp, model, **cfg): + return cls(nlp.vocab, model, **cfg) - architecture (str): The registered model architecture to use. - **cfg: Config parameters. - RETURNS (Model): A `thinc.model.Model` or similar instance. - """ - model = registry.architectures.get(architecture) - return model(**cfg) - - def __init__(self, vocab, model=True, **cfg): + def __init__(self, vocab, model, **cfg): """Construct a new statistical model. Weights are not allocated on initialisation. vocab (Vocab): A `Vocab` instance. The model must share the same `Vocab` instance with the `Doc` objects it will process. - model (Model): A `Model` instance or `True` to allocate one later. **cfg: Config parameters. """ self.vocab = vocab @@ -143,8 +132,6 @@ class Tok2Vec(Pipe): get_examples (function): Function returning example training data. pipeline (list): The pipeline the model is part of. """ - if self.model is True: - self.model = self.Model(**self.cfg) # TODO: use examples instead ? docs = [Doc(Vocab(), words=["hello"])] self.model.initialize(X=docs) diff --git a/spacy/syntax/_parser_model.pyx b/spacy/syntax/_parser_model.pyx index 442233f19..7ff9517a5 100644 --- a/spacy/syntax/_parser_model.pyx +++ b/spacy/syntax/_parser_model.pyx @@ -221,7 +221,10 @@ cdef int arg_max_if_valid(const weight_t* scores, const int* is_valid, int n) no class ParserModel(Model): def __init__(self, tok2vec, lower_model, upper_model, unseen_classes=None): - Model.__init__(self, name="parser_model", forward=forward) + # don't define nO for this object, because we can't dynamically change it + Model.__init__(self, name="parser_model", forward=forward, dims={"nI": None}) + if tok2vec.has_dim("nI"): + self.set_dim("nI", tok2vec.get_dim("nI")) self._layers = [tok2vec, lower_model] if upper_model is not None: self._layers.append(upper_model) @@ -229,6 +232,7 @@ class ParserModel(Model): if unseen_classes: for class_ in unseen_classes: self.unseen_classes.add(class_) + self.set_ref("tok2vec", tok2vec) def predict(self, docs): step_model = ParserStepModel(docs, self._layers, @@ -238,25 +242,32 @@ class ParserModel(Model): def resize_output(self, new_nO): if len(self._layers) == 2: return - if new_nO == self.upper.get_dim("nO"): + if self.upper.has_dim("nO") and (new_nO == self.upper.get_dim("nO")): return smaller = self.upper - nI = smaller.get_dim("nI") + nI = None + if smaller.has_dim("nI"): + nI = smaller.get_dim("nI") with use_ops('numpy'): - larger = Linear(new_nO, nI) - larger_W = larger.ops.alloc2f(new_nO, nI) - larger_b = larger.ops.alloc1f(new_nO) - smaller_W = smaller.get_param("W") - smaller_b = smaller.get_param("b") - # Weights are stored in (nr_out, nr_in) format, so we're basically - # just adding rows here. - larger_W[:smaller.get_dim("nO")] = smaller_W - larger_b[:smaller.get_dim("nO")] = smaller_b - larger.set_param("W", larger_W) - larger.set_param("b", larger_b) + larger = Linear(nO=new_nO, nI=nI) + larger._init = smaller._init + # it could be that the model is not initialized yet, then skip this bit + if nI: + larger_W = larger.ops.alloc2f(new_nO, nI) + larger_b = larger.ops.alloc1f(new_nO) + smaller_W = smaller.get_param("W") + smaller_b = smaller.get_param("b") + # Weights are stored in (nr_out, nr_in) format, so we're basically + # just adding rows here. + if smaller.has_dim("nO"): + larger_W[:smaller.get_dim("nO")] = smaller_W + larger_b[:smaller.get_dim("nO")] = smaller_b + for i in range(smaller.get_dim("nO"), new_nO): + self.unseen_classes.add(i) + + larger.set_param("W", larger_W) + larger.set_param("b", larger_b) self._layers[-1] = larger - for i in range(smaller.get_dim("nO"), new_nO): - self.unseen_classes.add(i) def initialize(self, X=None, Y=None): self.tok2vec.initialize() @@ -412,7 +423,7 @@ cdef class precompute_hiddens: we can do all our hard maths up front, packed into large multiplications, and do the hard-to-program parsing on the CPU. """ - cdef readonly int nF, nO, nP # TODO: make these more like the dimensions in thinc + cdef readonly int nF, nO, nP cdef bint _is_synchronized cdef public object ops cdef np.ndarray _features @@ -458,6 +469,16 @@ cdef class precompute_hiddens: self._is_synchronized = True return self._cached.data + def has_dim(self, name): + if name == "nF": + return self.nF if self.nF is not None else True + elif name == "nP": + return self.nP if self.nP is not None else True + elif name == "nO": + return self.nO if self.nO is not None else True + else: + return False + def get_dim(self, name): if name == "nF": return self.nF @@ -468,6 +489,16 @@ cdef class precompute_hiddens: else: raise ValueError(f"Dimension {name} invalid -- only nO, nF, nP") + def set_dim(self, name, value): + if name == "nF": + self.nF = value + elif name == "nP": + self.nP = value + elif name == "nO": + self.nO = value + else: + raise ValueError(f"Dimension {name} invalid -- only nO, nF, nP") + def __call__(self, X, bint is_train): if is_train: return self.begin_update(X) diff --git a/spacy/syntax/nn_parser.pyx b/spacy/syntax/nn_parser.pyx index cf57e1cf6..9381fab6b 100644 --- a/spacy/syntax/nn_parser.pyx +++ b/spacy/syntax/nn_parser.pyx @@ -27,11 +27,11 @@ from ._parser_model cimport predict_states, arg_max_if_valid from ._parser_model cimport WeightsC, ActivationsC, SizesC, cpu_log_loss from ._parser_model cimport get_c_weights, get_c_sizes from ._parser_model import ParserModel -from ..util import link_vectors_to_models, create_default_optimizer +from ..util import link_vectors_to_models, create_default_optimizer, registry from ..compat import copy_array from ..tokens.doc cimport Doc from ..gold cimport GoldParse -from ..errors import Errors, TempErrors +from ..errors import Errors, user_warning, Warnings from .. import util from .stateclass cimport StateClass from ._state cimport StateC @@ -41,114 +41,42 @@ from . import _beam_utils from . import nonproj -from ..ml._layers import PrecomputableAffine -from ..ml.component_models import Tok2Vec - - cdef class Parser: """ Base class of the DependencyParser and EntityRecognizer. """ - @classmethod - def Model(cls, nr_class, **cfg): - depth = util.env_opt('parser_hidden_depth', cfg.get('hidden_depth', 1)) - subword_features = util.env_opt('subword_features', - cfg.get('subword_features', True)) - conv_depth = util.env_opt('conv_depth', cfg.get('conv_depth', 4)) - conv_window = util.env_opt('conv_window', cfg.get('conv_window', 1)) - t2v_pieces = util.env_opt('cnn_maxout_pieces', cfg.get('cnn_maxout_pieces', 3)) - bilstm_depth = util.env_opt('bilstm_depth', cfg.get('bilstm_depth', 0)) - self_attn_depth = util.env_opt('self_attn_depth', cfg.get('self_attn_depth', 0)) - nr_feature_tokens = cfg.get("nr_feature_tokens", cls.nr_feature) - if depth not in (0, 1): - raise ValueError(TempErrors.T004.format(value=depth)) - parser_maxout_pieces = util.env_opt('parser_maxout_pieces', - cfg.get('maxout_pieces', 2)) - token_vector_width = util.env_opt('token_vector_width', - cfg.get('token_vector_width', 96)) - hidden_width = util.env_opt('hidden_width', cfg.get('hidden_width', 64)) - if depth == 0: - hidden_width = nr_class - parser_maxout_pieces = 1 - embed_size = util.env_opt('embed_size', cfg.get('embed_size', 2000)) - pretrained_vectors = cfg.get('pretrained_vectors', None) - tok2vec = Tok2Vec(width=token_vector_width, - embed_size=embed_size, - conv_depth=conv_depth, - window_size=conv_window, - cnn_maxout_pieces=t2v_pieces, - subword_features=subword_features, - pretrained_vectors=pretrained_vectors, - bilstm_depth=bilstm_depth) - tok2vec = chain(tok2vec, list2array()) - tok2vec.set_dim("nO", token_vector_width) - lower = PrecomputableAffine(hidden_width, - nF=nr_feature_tokens, nI=token_vector_width, - nP=parser_maxout_pieces) - lower.set_dim("nP", parser_maxout_pieces) - if depth == 1: - with use_ops('numpy'): - upper = Linear(nr_class, hidden_width, init_W=zero_init) - else: - upper = None - - cfg = { - 'nr_class': nr_class, - 'nr_feature_tokens': nr_feature_tokens, - 'hidden_depth': depth, - 'token_vector_width': token_vector_width, - 'hidden_width': hidden_width, - 'maxout_pieces': parser_maxout_pieces, - 'pretrained_vectors': pretrained_vectors, - 'bilstm_depth': bilstm_depth, - 'self_attn_depth': self_attn_depth, - 'conv_depth': conv_depth, - 'window_size': conv_window, - 'embed_size': embed_size, - 'cnn_maxout_pieces': t2v_pieces - } - model = ParserModel(tok2vec, lower, upper) - model.initialize() - return model, cfg - name = 'base_parser' - def __init__(self, Vocab vocab, moves=True, model=True, **cfg): + + def __init__(self, Vocab vocab, model, **cfg): """Create a Parser. vocab (Vocab): The vocabulary object. Must be shared with documents to be processed. The value is set to the `.vocab` attribute. - moves (TransitionSystem): Defines how the parse-state is created, - updated and evaluated. The value is set to the .moves attribute - unless True (default), in which case a new instance is created with - `Parser.Moves()`. - model (object): Defines how the parse-state is created, updated and - evaluated. The value is set to the .model attribute. If set to True - (default), a new instance will be created with `Parser.Model()` - in parser.begin_training(), parser.from_disk() or parser.from_bytes(). - **cfg: Arbitrary configuration parameters. Set to the `.cfg` attribute + **cfg: Configuration parameters. Set to the `.cfg` attribute. + If it doesn't include a value for 'moves', a new instance is + created with `self.TransitionSystem()`. This defines how the + parse-state is created, updated and evaluated. """ self.vocab = vocab - if moves is True: - self.moves = self.TransitionSystem(self.vocab.strings) - else: - self.moves = moves - if 'beam_width' not in cfg: - cfg['beam_width'] = util.env_opt('beam_width', 1) - if 'beam_density' not in cfg: - cfg['beam_density'] = util.env_opt('beam_density', 0.0) - if 'beam_update_prob' not in cfg: - cfg['beam_update_prob'] = util.env_opt('beam_update_prob', 1.0) - cfg.setdefault('cnn_maxout_pieces', 3) - cfg.setdefault("nr_feature_tokens", self.nr_feature) - self.cfg = cfg + moves = cfg.get("moves", None) + if moves is None: + # defined by EntityRecognizer as a BiluoPushDown + moves = self.TransitionSystem(self.vocab.strings) + self.moves = moves + cfg.setdefault('min_action_freq', 30) + cfg.setdefault('learn_tokens', False) + cfg.setdefault('beam_width', 1) + cfg.setdefault('beam_update_prob', 1.0) # or 0.5 (both defaults were previously used) self.model = model + self.set_output(self.moves.n_moves) + self.cfg = cfg self._multitasks = [] self._rehearsal_model = None @classmethod - def from_nlp(cls, nlp, **cfg): - return cls(nlp.vocab, **cfg) + def from_nlp(cls, nlp, model, **cfg): + return cls(nlp.vocab, model, **cfg) def __reduce__(self): return (Parser, (self.vocab, self.moves, self.model), None, None) @@ -163,8 +91,6 @@ cdef class Parser: names.append(name) return names - nr_feature = 8 - @property def labels(self): class_names = [self.moves.get_class_name(i) for i in range(self.moves.n_moves)] @@ -173,7 +99,7 @@ cdef class Parser: @property def tok2vec(self): '''Return the embedding and convolutional layer of the model.''' - return None if self.model in (None, True, False) else self.model.tok2vec + return self.model.tok2vec @property def postprocesses(self): @@ -190,10 +116,7 @@ cdef class Parser: self._resize() def _resize(self): - if "nr_class" in self.cfg: - self.cfg["nr_class"] = self.moves.n_moves - if self.model not in (True, False, None): - self.model.resize_output(self.moves.n_moves) + self.model.resize_output(self.moves.n_moves) if self._rehearsal_model not in (True, False, None): self._rehearsal_model.resize_output(self.moves.n_moves) @@ -227,7 +150,7 @@ cdef class Parser: doc (Doc): The document to be processed. """ if beam_width is None: - beam_width = self.cfg.get('beam_width', 1) + beam_width = self.cfg['beam_width'] beam_density = self.cfg.get('beam_density', 0.) states = self.predict([doc], beam_width=beam_width, beam_density=beam_density) @@ -243,7 +166,7 @@ cdef class Parser: YIELDS (Doc): Documents, in order. """ if beam_width is None: - beam_width = self.cfg.get('beam_width', 1) + beam_width = self.cfg['beam_width'] beam_density = self.cfg.get('beam_density', 0.) cdef Doc doc for batch in util.minibatch(docs, size=batch_size): @@ -264,13 +187,7 @@ cdef class Parser: else: yield from batch_in_order - def require_model(self): - """Raise an error if the component's model is not initialized.""" - if getattr(self, 'model', None) in (None, True, False): - raise ValueError(Errors.E109.format(name=self.name)) - def predict(self, docs, beam_width=1, beam_density=0.0, drop=0.): - self.require_model() if isinstance(docs, Doc): docs = [docs] if not any(len(doc) for doc in docs): @@ -313,11 +230,11 @@ cdef class Parser: # if labels are missing. We therefore have to check whether we need to # expand our model output. self._resize() + cdef int nr_feature = self.model.lower.get_dim("nF") model = self.model.predict(docs) - token_ids = numpy.zeros((len(docs) * beam_width, self.nr_feature), + token_ids = numpy.zeros((len(docs) * beam_width, nr_feature), dtype='i', order='C') cdef int* c_ids - cdef int nr_feature = self.cfg["nr_feature_tokens"] cdef int n_states model = self.model.predict(docs) todo = [beam for beam in beams if not beam.is_done] @@ -430,7 +347,6 @@ cdef class Parser: return [b for b in beams if not b.is_done] def update(self, examples, drop=0., set_annotations=False, sgd=None, losses=None): - self.require_model() examples = Example.to_example_objects(examples) if losses is None: @@ -440,9 +356,9 @@ cdef class Parser: multitask.update(examples, drop=drop, sgd=sgd) # The probability we use beam update, instead of falling back to # a greedy update - beam_update_prob = self.cfg.get('beam_update_prob', 0.5) - if self.cfg.get('beam_width', 1) >= 2 and numpy.random.random() < beam_update_prob: - return self.update_beam(examples, self.cfg.get('beam_width', 1), + beam_update_prob = self.cfg['beam_update_prob'] + if self.cfg['beam_width'] >= 2 and numpy.random.random() < beam_update_prob: + return self.update_beam(examples, self.cfg['beam_width'], drop=drop, sgd=sgd, losses=losses, set_annotations=set_annotations, beam_density=self.cfg.get('beam_density', 0.001)) @@ -533,7 +449,7 @@ cdef class Parser: set_dropout_rate(self.model, drop) model, backprop_tok2vec = self.model.begin_update(docs) states_d_scores, backprops, beams = _beam_utils.update_beam( - self.moves, self.cfg["nr_feature_tokens"], 10000, states, golds, + self.moves, self.model.lower.get_dim("nF"), 10000, states, golds, model.state2vec, model.vec2scores, width, losses=losses, beam_density=beam_density) for i, d_scores in enumerate(states_d_scores): @@ -562,8 +478,6 @@ cdef class Parser: keyed by the parameter ID. The values are (weights, gradients) tuples. """ gradients = {} - if self.model in (None, True, False): - return gradients queue = [self.model] seen = set() for node in queue: @@ -647,45 +561,40 @@ cdef class Parser: def create_optimizer(self): return create_default_optimizer() - def begin_training(self, get_examples, pipeline=None, sgd=None, **cfg): - if 'model' in cfg: - self.model = cfg['model'] + def set_output(self, nO): + if self.model.upper.has_dim("nO") is None: + self.model.upper.set_dim("nO", nO) + + def begin_training(self, get_examples, pipeline=None, sgd=None, **kwargs): + self.cfg.update(kwargs) if not hasattr(get_examples, '__call__'): gold_tuples = get_examples get_examples = lambda: gold_tuples - cfg.setdefault('min_action_freq', 30) actions = self.moves.get_actions(gold_parses=get_examples(), - min_freq=cfg.get('min_action_freq', 30), - learn_tokens=self.cfg.get("learn_tokens", False)) + min_freq=self.cfg['min_action_freq'], + learn_tokens=self.cfg["learn_tokens"]) for action, labels in self.moves.labels.items(): actions.setdefault(action, {}) for label, freq in labels.items(): if label not in actions[action]: actions[action][label] = freq self.moves.initialize_actions(actions) - cfg.setdefault('token_vector_width', 96) - if self.model is True: - self.model, cfg = self.Model(self.moves.n_moves, **cfg) - if sgd is None: - sgd = self.create_optimizer() - doc_sample = [] - gold_sample = [] - for example in islice(get_examples(), 1000): - parses = example.get_gold_parses(merge=False, vocab=self.vocab) - for doc, gold in parses: - doc_sample.append(doc) - gold_sample.append(gold) - self.model.initialize(doc_sample, gold_sample) - if pipeline is not None: - self.init_multitask_objectives(get_examples, pipeline, sgd=sgd, **cfg) - link_vectors_to_models(self.vocab) - else: - if sgd is None: - sgd = self.create_optimizer() - if self.model.upper.has_dim("nO") is None: - self.model.upper.set_dim("nO", self.moves.n_moves) - self.model.initialize() - self.cfg.update(cfg) + # make sure we resize so we have an appropriate upper layer + self._resize() + if sgd is None: + sgd = self.create_optimizer() + doc_sample = [] + gold_sample = [] + for example in islice(get_examples(), 1000): + parses = example.get_gold_parses(merge=False, vocab=self.vocab) + for doc, gold in parses: + doc_sample.append(doc) + gold_sample.append(gold) + + self.model.initialize(doc_sample, gold_sample) + if pipeline is not None: + self.init_multitask_objectives(get_examples, pipeline, sgd=sgd, **self.cfg) + link_vectors_to_models(self.vocab) return sgd def _get_doc(self, example): @@ -709,28 +618,24 @@ cdef class Parser: 'vocab': lambda p: self.vocab.from_disk(p), 'moves': lambda p: self.moves.from_disk(p, exclude=["strings"]), 'cfg': lambda p: self.cfg.update(srsly.read_json(p)), - 'model': lambda p: None + 'model': lambda p: None, } exclude = util.get_serialization_exclude(deserializers, exclude, kwargs) util.from_disk(path, deserializers, exclude) if 'model' not in exclude: path = util.ensure_path(path) - if self.model is True: - self.model, cfg = self.Model(**self.cfg) - else: - cfg = {} with (path / 'model').open('rb') as file_: bytes_data = file_.read() try: + self._resize() self.model.from_bytes(bytes_data) except AttributeError: raise ValueError(Errors.E149) - self.cfg.update(cfg) return self def to_bytes(self, exclude=tuple(), **kwargs): serializers = { - "model": lambda: (self.model.to_bytes() if self.model is not True else True), + "model": lambda: (self.model.to_bytes()), "vocab": lambda: self.vocab.to_bytes(), "moves": lambda: self.moves.to_bytes(exclude=["strings"]), "cfg": lambda: srsly.json_dumps(self.cfg, indent=2, sort_keys=True) @@ -743,22 +648,14 @@ cdef class Parser: "vocab": lambda b: self.vocab.from_bytes(b), "moves": lambda b: self.moves.from_bytes(b, exclude=["strings"]), "cfg": lambda b: self.cfg.update(srsly.json_loads(b)), - "model": lambda b: None + "model": lambda b: None, } exclude = util.get_serialization_exclude(deserializers, exclude, kwargs) msg = util.from_bytes(bytes_data, deserializers, exclude) if 'model' not in exclude: - # TODO: Remove this once we don't have to handle previous models - if self.cfg.get('pretrained_dims') and 'pretrained_vectors' not in self.cfg: - self.cfg['pretrained_vectors'] = self.vocab.vectors - if self.model is True: - self.model, cfg = self.Model(**self.cfg) - else: - cfg = {} if 'model' in msg: try: self.model.from_bytes(msg['model']) except AttributeError: raise ValueError(Errors.E149) - self.cfg.update(cfg) return self diff --git a/spacy/tests/doc/test_add_entities.py b/spacy/tests/doc/test_add_entities.py index 766dcb739..3a466b24c 100644 --- a/spacy/tests/doc/test_add_entities.py +++ b/spacy/tests/doc/test_add_entities.py @@ -3,12 +3,13 @@ from spacy.tokens import Span import pytest from ..util import get_doc +from ...ml.models.defaults import default_ner def test_doc_add_entities_set_ents_iob(en_vocab): text = ["This", "is", "a", "lion"] doc = get_doc(en_vocab, text) - ner = EntityRecognizer(en_vocab) + ner = EntityRecognizer(en_vocab, default_ner()) ner.begin_training([]) ner(doc) assert len(list(doc.ents)) == 0 @@ -24,7 +25,7 @@ def test_doc_add_entities_set_ents_iob(en_vocab): def test_ents_reset(en_vocab): text = ["This", "is", "a", "lion"] doc = get_doc(en_vocab, text) - ner = EntityRecognizer(en_vocab) + ner = EntityRecognizer(en_vocab, default_ner()) ner.begin_training([]) ner(doc) assert [t.ent_iob_ for t in doc] == (["O"] * len(doc)) diff --git a/spacy/tests/parser/test_add_label.py b/spacy/tests/parser/test_add_label.py index fe847a6ae..5af772ddc 100644 --- a/spacy/tests/parser/test_add_label.py +++ b/spacy/tests/parser/test_add_label.py @@ -3,6 +3,8 @@ from thinc.api import Adam, NumpyOps from spacy.attrs import NORM from spacy.gold import GoldParse from spacy.vocab import Vocab + +from spacy.ml.models.defaults import default_parser, default_ner from spacy.tokens import Doc from spacy.pipeline import DependencyParser, EntityRecognizer from spacy.util import fix_random_seed @@ -15,7 +17,7 @@ def vocab(): @pytest.fixture def parser(vocab): - parser = DependencyParser(vocab) + parser = DependencyParser(vocab, default_parser()) return parser @@ -55,27 +57,31 @@ def test_add_label(parser): def test_add_label_deserializes_correctly(): - ner1 = EntityRecognizer(Vocab()) + ner1 = EntityRecognizer(Vocab(), default_ner()) ner1.add_label("C") ner1.add_label("B") ner1.add_label("A") ner1.begin_training([]) - ner2 = EntityRecognizer(Vocab()).from_bytes(ner1.to_bytes()) + ner2 = EntityRecognizer(Vocab(), default_ner()) + + # the second model needs to be resized before we can call from_bytes + ner2.model.resize_output(ner1.moves.n_moves) + ner2.from_bytes(ner1.to_bytes()) assert ner1.moves.n_moves == ner2.moves.n_moves for i in range(ner1.moves.n_moves): assert ner1.moves.get_class_name(i) == ner2.moves.get_class_name(i) @pytest.mark.parametrize( - "pipe_cls,n_moves", [(DependencyParser, 5), (EntityRecognizer, 4)] + "pipe_cls,n_moves,model", [(DependencyParser, 5, default_parser()), (EntityRecognizer, 4, default_ner())] ) -def test_add_label_get_label(pipe_cls, n_moves): +def test_add_label_get_label(pipe_cls, n_moves, model): """Test that added labels are returned correctly. This test was added to test for a bug in DependencyParser.labels that'd cause it to fail when splitting the move names. """ labels = ["A", "B", "C"] - pipe = pipe_cls(Vocab()) + pipe = pipe_cls(Vocab(), model) for label in labels: pipe.add_label(label) assert len(pipe.move_names) == len(labels) * n_moves diff --git a/spacy/tests/parser/test_arc_eager_oracle.py b/spacy/tests/parser/test_arc_eager_oracle.py index dd593f7d3..2426805d2 100644 --- a/spacy/tests/parser/test_arc_eager_oracle.py +++ b/spacy/tests/parser/test_arc_eager_oracle.py @@ -1,5 +1,7 @@ import pytest from spacy.vocab import Vocab + +from spacy.ml.models.defaults import default_parser from spacy.pipeline import DependencyParser from spacy.tokens import Doc from spacy.gold import GoldParse @@ -136,7 +138,7 @@ def test_get_oracle_actions(): deps.append(dep) ents.append(ent) doc = Doc(Vocab(), words=[t[1] for t in annot_tuples]) - parser = DependencyParser(doc.vocab) + parser = DependencyParser(doc.vocab, default_parser()) parser.moves.add_action(0, "") parser.moves.add_action(1, "") parser.moves.add_action(1, "") diff --git a/spacy/tests/parser/test_ner.py b/spacy/tests/parser/test_ner.py index 9a4d21a8d..3fde75eb5 100644 --- a/spacy/tests/parser/test_ner.py +++ b/spacy/tests/parser/test_ner.py @@ -1,10 +1,15 @@ import pytest + +from spacy import util from spacy.lang.en import English +from spacy.ml.models.defaults import default_ner from spacy.pipeline import EntityRecognizer, EntityRuler from spacy.vocab import Vocab from spacy.syntax.ner import BiluoPushDown from spacy.gold import GoldParse + +from spacy.tests.util import make_tempdir from spacy.tokens import Doc TRAIN_DATA = [ @@ -134,7 +139,7 @@ def test_accept_blocked_token(): # 1. test normal behaviour nlp1 = English() doc1 = nlp1("I live in New York") - ner1 = EntityRecognizer(doc1.vocab) + ner1 = EntityRecognizer(doc1.vocab, default_ner()) assert [token.ent_iob_ for token in doc1] == ["", "", "", "", ""] assert [token.ent_type_ for token in doc1] == ["", "", "", "", ""] @@ -152,7 +157,7 @@ def test_accept_blocked_token(): # 2. test blocking behaviour nlp2 = English() doc2 = nlp2("I live in New York") - ner2 = EntityRecognizer(doc2.vocab) + ner2 = EntityRecognizer(doc2.vocab, default_ner()) # set "New York" to a blocked entity doc2.ents = [(0, 3, 5)] @@ -188,7 +193,7 @@ def test_overwrite_token(): assert [token.ent_type_ for token in doc] == ["", "", "", "", ""] # Check that a new ner can overwrite O - ner2 = EntityRecognizer(doc.vocab) + ner2 = EntityRecognizer(doc.vocab, default_ner()) ner2.moves.add_action(5, "") ner2.add_label("GPE") state = ner2.moves.init_batch([doc])[0] @@ -199,6 +204,17 @@ def test_overwrite_token(): assert ner2.moves.is_valid(state, "L-GPE") +def test_empty_ner(): + nlp = English() + ner = nlp.create_pipe("ner") + ner.add_label("MY_LABEL") + nlp.add_pipe(ner) + nlp.begin_training() + doc = nlp("John is watching the news about Croatia's elections") + # if this goes wrong, the initialization of the parser's upper layer is probably broken + assert [token.ent_iob_ for token in doc] == ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O'] + + def test_ruler_before_ner(): """ Test that an NER works after an entity_ruler: the second can add annotations """ nlp = English() @@ -214,7 +230,6 @@ def test_ruler_before_ner(): untrained_ner.add_label("MY_LABEL") nlp.add_pipe(untrained_ner) nlp.begin_training() - doc = nlp("This is Antti Korhonen speaking in Finland") expected_iobs = ["B", "O", "O", "O", "O", "O", "O"] expected_types = ["THING", "", "", "", "", "", ""] @@ -261,28 +276,7 @@ def test_block_ner(): assert [token.ent_type_ for token in doc] == expected_types -def test_change_number_features(): - # Test the default number features - nlp = English() - ner = nlp.create_pipe("ner") - nlp.add_pipe(ner) - ner.add_label("PERSON") - nlp.begin_training() - assert ner.model.lower.get_dim("nF") == ner.nr_feature - # Test we can change it - nlp = English() - ner = nlp.create_pipe("ner") - nlp.add_pipe(ner) - ner.add_label("PERSON") - nlp.begin_training( - component_cfg={"ner": {"nr_feature_tokens": 3, "token_vector_width": 128}} - ) - assert ner.model.lower.get_dim("nF") == 3 - # Test the model runs - nlp("hello world") - - -def test_overfitting(): +def test_overfitting_IO(): # Simple test to try and quickly overfit the NER component - ensuring the ML models work correctly nlp = English() ner = nlp.create_pipe("ner") @@ -301,11 +295,20 @@ def test_overfitting(): test_text = "I like London." doc = nlp(test_text) ents = doc.ents - assert len(ents) == 1 assert ents[0].text == "London" assert ents[0].label_ == "LOC" + # Also test the results are still the same after IO + with make_tempdir() as tmp_dir: + nlp.to_disk(tmp_dir) + nlp2 = util.load_model_from_path(tmp_dir) + doc2 = nlp2(test_text) + ents2 = doc2.ents + assert len(ents2) == 1 + assert ents2[0].text == "London" + assert ents2[0].label_ == "LOC" + class BlockerComponent1(object): name = "my_blocker" diff --git a/spacy/tests/parser/test_neural_parser.py b/spacy/tests/parser/test_neural_parser.py index 2470982d3..984af4d6b 100644 --- a/spacy/tests/parser/test_neural_parser.py +++ b/spacy/tests/parser/test_neural_parser.py @@ -1,8 +1,9 @@ import pytest -from spacy.ml.component_models import Tok2Vec +from spacy.ml.models.defaults import default_parser, default_tok2vec from spacy.vocab import Vocab from spacy.syntax.arc_eager import ArcEager from spacy.syntax.nn_parser import Parser +from spacy.syntax._parser_model import ParserModel from spacy.tokens.doc import Doc from spacy.gold import GoldParse @@ -20,19 +21,22 @@ def arc_eager(vocab): @pytest.fixture def tok2vec(): - tok2vec = Tok2Vec(8, 100) + tok2vec = default_tok2vec() tok2vec.initialize() return tok2vec @pytest.fixture def parser(vocab, arc_eager): - return Parser(vocab, moves=arc_eager, model=None) + return Parser(vocab, model=default_parser(), moves=arc_eager) @pytest.fixture -def model(arc_eager, tok2vec): - return Parser.Model(arc_eager.n_moves, token_vector_width=tok2vec.get_dim("nO"))[0] +def model(arc_eager, tok2vec, vocab): + model = default_parser() + model.resize_output(arc_eager.n_moves) + model.initialize() + return model @pytest.fixture @@ -46,11 +50,11 @@ def gold(doc): def test_can_init_nn_parser(parser): - assert parser.model is None + assert isinstance(parser.model, ParserModel) -def test_build_model(parser): - parser.model = Parser.Model(parser.moves.n_moves, hist_size=0)[0] +def test_build_model(parser, vocab): + parser.model = Parser(vocab, model=default_parser(), moves=parser.moves).model assert parser.model is not None diff --git a/spacy/tests/parser/test_nn_beam.py b/spacy/tests/parser/test_nn_beam.py index 24997e47c..619e0cc0b 100644 --- a/spacy/tests/parser/test_nn_beam.py +++ b/spacy/tests/parser/test_nn_beam.py @@ -2,6 +2,7 @@ import pytest import numpy from spacy.vocab import Vocab from spacy.language import Language +from spacy.ml.models.defaults import default_parser from spacy.pipeline import DependencyParser from spacy.syntax.arc_eager import ArcEager from spacy.tokens import Doc @@ -93,7 +94,7 @@ def test_beam_advance_too_few_scores(beam, scores): def test_beam_parse(): nlp = Language() - nlp.add_pipe(DependencyParser(nlp.vocab), name="parser") + nlp.add_pipe(DependencyParser(nlp.vocab, default_parser()), name="parser") nlp.parser.add_label("nsubj") nlp.parser.begin_training([], token_vector_width=8, hidden_width=8) doc = nlp.make_doc("Australia is a country") diff --git a/spacy/tests/parser/test_parse.py b/spacy/tests/parser/test_parse.py index 1d3f522c9..6e13d3044 100644 --- a/spacy/tests/parser/test_parse.py +++ b/spacy/tests/parser/test_parse.py @@ -1,7 +1,8 @@ import pytest from spacy.lang.en import English -from ..util import get_doc, apply_transition_sequence +from ..util import get_doc, apply_transition_sequence, make_tempdir +from ... import util TRAIN_DATA = [ ( @@ -182,7 +183,7 @@ def test_parser_set_sent_starts(en_vocab): assert token.head in sent -def test_overfitting(): +def test_overfitting_IO(): # Simple test to try and quickly overfit the dependency parser - ensuring the ML models work correctly nlp = English() parser = nlp.create_pipe("parser") @@ -200,7 +201,15 @@ def test_overfitting(): # test the trained model test_text = "I like securities." doc = nlp(test_text) - assert doc[0].dep_ is "nsubj" assert doc[2].dep_ is "dobj" assert doc[3].dep_ is "punct" + + # Also test the results are still the same after IO + with make_tempdir() as tmp_dir: + nlp.to_disk(tmp_dir) + nlp2 = util.load_model_from_path(tmp_dir) + doc2 = nlp2(test_text) + assert doc2[0].dep_ is "nsubj" + assert doc2[2].dep_ is "dobj" + assert doc2[3].dep_ is "punct" diff --git a/spacy/tests/parser/test_preset_sbd.py b/spacy/tests/parser/test_preset_sbd.py index c6c1240a8..af777aa6b 100644 --- a/spacy/tests/parser/test_preset_sbd.py +++ b/spacy/tests/parser/test_preset_sbd.py @@ -3,6 +3,8 @@ from thinc.api import Adam from spacy.attrs import NORM from spacy.gold import GoldParse from spacy.vocab import Vocab + +from spacy.ml.models.defaults import default_parser from spacy.tokens import Doc from spacy.pipeline import DependencyParser @@ -14,7 +16,7 @@ def vocab(): @pytest.fixture def parser(vocab): - parser = DependencyParser(vocab) + parser = DependencyParser(vocab, default_parser()) parser.cfg["token_vector_width"] = 4 parser.cfg["hidden_width"] = 32 # parser.add_label('right') diff --git a/spacy/tests/pipeline/test_analysis.py b/spacy/tests/pipeline/test_analysis.py index 5c246538c..cda39f6ee 100644 --- a/spacy/tests/pipeline/test_analysis.py +++ b/spacy/tests/pipeline/test_analysis.py @@ -111,7 +111,8 @@ def test_component_factories_from_nlp(): nlp.add_pipe(pipe) assert nlp("hello world") # The first argument here is the class itself, so we're accepting any here - mock.assert_called_once_with(ANY, nlp, foo="bar") + # The model will be initialized to None by the factory + mock.assert_called_once_with(ANY, nlp, None, foo="bar") def test_analysis_validate_attrs_valid(): diff --git a/spacy/tests/pipeline/test_tagger.py b/spacy/tests/pipeline/test_tagger.py index 366cd4f1a..a90207a78 100644 --- a/spacy/tests/pipeline/test_tagger.py +++ b/spacy/tests/pipeline/test_tagger.py @@ -1,5 +1,9 @@ import pytest + +from spacy import util +from spacy.lang.en import English from spacy.language import Language +from spacy.tests.util import make_tempdir def test_label_types(): @@ -18,9 +22,9 @@ TRAIN_DATA = [ ] -def test_overfitting(): +def test_overfitting_IO(): # Simple test to try and quickly overfit the tagger - ensuring the ML models work correctly - nlp = Language() + nlp = English() tagger = nlp.create_pipe("tagger") for tag, values in TAG_MAP.items(): tagger.add_label(tag, values) @@ -35,8 +39,17 @@ def test_overfitting(): # test the trained model test_text = "I like blue eggs" doc = nlp(test_text) - assert doc[0].tag_ is "N" assert doc[1].tag_ is "V" assert doc[2].tag_ is "J" assert doc[3].tag_ is "N" + + # Also test the results are still the same after IO + with make_tempdir() as tmp_dir: + nlp.to_disk(tmp_dir) + nlp2 = util.load_model_from_path(tmp_dir) + doc2 = nlp2(test_text) + assert doc2[0].tag_ is "N" + assert doc2[1].tag_ is "V" + assert doc2[2].tag_ is "J" + assert doc2[3].tag_ is "N" diff --git a/spacy/tests/pipeline/test_textcat.py b/spacy/tests/pipeline/test_textcat.py index 558d09e40..1b5ca9a4c 100644 --- a/spacy/tests/pipeline/test_textcat.py +++ b/spacy/tests/pipeline/test_textcat.py @@ -1,8 +1,12 @@ import pytest import random import numpy.random + +from spacy import util +from spacy.lang.en import English from spacy.language import Language from spacy.pipeline import TextCategorizer +from spacy.tests.util import make_tempdir from spacy.tokens import Doc from spacy.gold import GoldParse @@ -74,9 +78,9 @@ def test_label_types(): nlp.get_pipe("textcat").add_label(9) -def test_overfitting(): +def test_overfitting_IO(): # Simple test to try and quickly overfit the textcat component - ensuring the ML models work correctly - nlp = Language() + nlp = English() textcat = nlp.create_pipe("textcat") for _, annotations in TRAIN_DATA: for label, value in annotations.get("cats").items(): @@ -87,11 +91,21 @@ def test_overfitting(): for i in range(50): losses = {} nlp.update(TRAIN_DATA, sgd=optimizer, losses=losses) - assert losses["textcat"] < 0.00001 + assert losses["textcat"] < 0.01 # test the trained model test_text = "I am happy." doc = nlp(test_text) cats = doc.cats + # note that by default, exclusive_classes = false so we need a bigger error margin assert cats["POSITIVE"] > 0.9 - assert cats["POSITIVE"] + cats["NEGATIVE"] == pytest.approx(1.0, 0.001) + assert cats["POSITIVE"] + cats["NEGATIVE"] == pytest.approx(1.0, 0.1) + + # Also test the results are still the same after IO + with make_tempdir() as tmp_dir: + nlp.to_disk(tmp_dir) + nlp2 = util.load_model_from_path(tmp_dir) + doc2 = nlp2(test_text) + cats2 = doc2.cats + assert cats2["POSITIVE"] > 0.9 + assert cats2["POSITIVE"] + cats2["NEGATIVE"] == pytest.approx(1.0, 0.1) diff --git a/spacy/tests/regression/test_issue1501-2000.py b/spacy/tests/regression/test_issue1501-2000.py index 2bfdbd7c3..ff8c7c2fe 100644 --- a/spacy/tests/regression/test_issue1501-2000.py +++ b/spacy/tests/regression/test_issue1501-2000.py @@ -10,6 +10,7 @@ from spacy.lang.lex_attrs import is_stop from spacy.vectors import Vectors from spacy.vocab import Vocab from spacy.language import Language +from spacy.ml.models.defaults import default_ner, default_tagger from spacy.tokens import Doc, Span, Token from spacy.pipeline import Tagger, EntityRecognizer from spacy.attrs import HEAD, DEP @@ -123,7 +124,7 @@ def test_issue1727(): correctly after vectors are added.""" data = numpy.ones((3, 300), dtype="f") vectors = Vectors(data=data, keys=["I", "am", "Matt"]) - tagger = Tagger(Vocab()) + tagger = Tagger(Vocab(), default_tagger()) tagger.add_label("PRP") with pytest.warns(UserWarning): tagger.begin_training() @@ -131,7 +132,7 @@ def test_issue1727(): tagger.vocab.vectors = vectors with make_tempdir() as path: tagger.to_disk(path) - tagger = Tagger(Vocab()).from_disk(path) + tagger = Tagger(Vocab(), default_tagger()).from_disk(path) assert tagger.cfg.get("pretrained_dims", 0) == 0 @@ -236,6 +237,7 @@ def test_issue1889(word): assert is_stop(word, STOP_WORDS) == is_stop(word.upper(), STOP_WORDS) +@pytest.mark.skip(reason="This test has become obsolete with the config refactor of v.3") def test_issue1915(): cfg = {"hidden_depth": 2} # should error out nlp = Language() @@ -268,7 +270,7 @@ def test_issue1963(en_tokenizer): @pytest.mark.parametrize("label", ["U-JOB-NAME"]) def test_issue1967(label): - ner = EntityRecognizer(Vocab()) + ner = EntityRecognizer(Vocab(), default_ner()) example = Example(doc=None) example.set_token_annotation( ids=[0], words=["word"], tags=["tag"], heads=[0], deps=["dep"], entities=[label] diff --git a/spacy/tests/regression/test_issue2001-2500.py b/spacy/tests/regression/test_issue2001-2500.py index 2c25b6d73..1786677e0 100644 --- a/spacy/tests/regression/test_issue2001-2500.py +++ b/spacy/tests/regression/test_issue2001-2500.py @@ -32,6 +32,9 @@ def test_issue2179(): nlp.begin_training() nlp2 = Italian() nlp2.add_pipe(nlp2.create_pipe("ner")) + + assert len(nlp2.get_pipe("ner").labels) == 0 + nlp2.get_pipe("ner").model.resize_output(nlp.get_pipe("ner").moves.n_moves) nlp2.from_bytes(nlp.to_bytes()) assert "extra_labels" not in nlp2.get_pipe("ner").cfg assert nlp2.get_pipe("ner").labels == ("CITIZENSHIP",) diff --git a/spacy/tests/regression/test_issue3001-3500.py b/spacy/tests/regression/test_issue3001-3500.py index cc893e472..df23efa4f 100644 --- a/spacy/tests/regression/test_issue3001-3500.py +++ b/spacy/tests/regression/test_issue3001-3500.py @@ -1,6 +1,7 @@ import pytest from spacy.lang.en import English from spacy.lang.de import German +from spacy.ml.models.defaults import default_ner from spacy.pipeline import EntityRuler, EntityRecognizer from spacy.matcher import Matcher, PhraseMatcher from spacy.tokens import Doc @@ -103,6 +104,7 @@ def test_issue3209(): assert ner.move_names == move_names nlp2 = English() nlp2.add_pipe(nlp2.create_pipe("ner")) + nlp2.get_pipe("ner").model.resize_output(ner.moves.n_moves) nlp2.from_bytes(nlp.to_bytes()) assert nlp2.get_pipe("ner").move_names == move_names @@ -193,7 +195,7 @@ def test_issue3345(): doc = Doc(nlp.vocab, words=["I", "live", "in", "New", "York"]) doc[4].is_sent_start = True ruler = EntityRuler(nlp, patterns=[{"label": "GPE", "pattern": "New York"}]) - ner = EntityRecognizer(doc.vocab) + ner = EntityRecognizer(doc.vocab, default_ner()) # Add the OUT action. I wouldn't have thought this would be necessary... ner.moves.add_action(5, "") ner.add_label("GPE") diff --git a/spacy/tests/regression/test_issue3830.py b/spacy/tests/regression/test_issue3830.py index 54ce10924..9752f70df 100644 --- a/spacy/tests/regression/test_issue3830.py +++ b/spacy/tests/regression/test_issue3830.py @@ -1,10 +1,12 @@ from spacy.pipeline.pipes import DependencyParser from spacy.vocab import Vocab +from spacy.ml.models.defaults import default_parser + def test_issue3830_no_subtok(): """Test that the parser doesn't have subtok label if not learn_tokens""" - parser = DependencyParser(Vocab()) + parser = DependencyParser(Vocab(), default_parser()) parser.add_label("nsubj") assert "subtok" not in parser.labels parser.begin_training(lambda: []) @@ -13,7 +15,7 @@ def test_issue3830_no_subtok(): def test_issue3830_with_subtok(): """Test that the parser does have subtok label if learn_tokens=True.""" - parser = DependencyParser(Vocab(), learn_tokens=True) + parser = DependencyParser(Vocab(), default_parser(), learn_tokens=True) parser.add_label("nsubj") assert "subtok" not in parser.labels parser.begin_training(lambda: []) diff --git a/spacy/tests/regression/test_issue4042.py b/spacy/tests/regression/test_issue4042.py index 6644a8eda..75a1c23b7 100644 --- a/spacy/tests/regression/test_issue4042.py +++ b/spacy/tests/regression/test_issue4042.py @@ -3,6 +3,7 @@ from spacy.pipeline import EntityRecognizer, EntityRuler from spacy.lang.en import English from spacy.tokens import Span from spacy.util import ensure_path +from spacy.ml.models.defaults import default_ner from ..util import make_tempdir @@ -73,6 +74,6 @@ def test_issue4042_bug2(): output_dir.mkdir() ner1.to_disk(output_dir) - ner2 = EntityRecognizer(vocab) + ner2 = EntityRecognizer(vocab, default_ner()) ner2.from_disk(output_dir) assert len(ner2.labels) == 2 diff --git a/spacy/tests/regression/test_issue4313.py b/spacy/tests/regression/test_issue4313.py index a3f6f69df..30688601f 100644 --- a/spacy/tests/regression/test_issue4313.py +++ b/spacy/tests/regression/test_issue4313.py @@ -1,5 +1,6 @@ from collections import defaultdict +from spacy.ml.models.defaults import default_ner from spacy.pipeline import EntityRecognizer from spacy.lang.en import English @@ -11,7 +12,7 @@ def test_issue4313(): beam_width = 16 beam_density = 0.0001 nlp = English() - ner = EntityRecognizer(nlp.vocab) + ner = EntityRecognizer(nlp.vocab, default_ner()) ner.add_label("SOME_LABEL") ner.begin_training([]) nlp.add_pipe(ner) diff --git a/spacy/tests/serialize/test_serialize_config.py b/spacy/tests/serialize/test_serialize_config.py new file mode 100644 index 000000000..c34d01547 --- /dev/null +++ b/spacy/tests/serialize/test_serialize_config.py @@ -0,0 +1,126 @@ +from thinc.api import Config + +import spacy +from spacy import util +from spacy.lang.en import English +from spacy.util import registry + +from ..util import make_tempdir +from ...ml.models import build_Tok2Vec_model, build_tb_parser_model + +nlp_config_string = """ +[nlp] +lang = "en" + +[nlp.pipeline.tok2vec] +factory = "tok2vec" + +[nlp.pipeline.tok2vec.model] +@architectures = "spacy.HashEmbedCNN.v1" +pretrained_vectors = null +width = 342 +depth = 4 +window_size = 1 +embed_size = 2000 +maxout_pieces = 3 +subword_features = true + +[nlp.pipeline.tagger] +factory = "tagger" + +[nlp.pipeline.tagger.model] +@architectures = "spacy.Tagger.v1" + +[nlp.pipeline.tagger.model.tok2vec] +@architectures = "spacy.Tok2VecTensors.v1" +width = ${nlp.pipeline.tok2vec.model:width} +""" + + +parser_config_string = """ +[model] +@architectures = "spacy.TransitionBasedParser.v1" +nr_feature_tokens = 99 +hidden_width = 66 +maxout_pieces = 2 + +[model.tok2vec] +@architectures = "spacy.HashEmbedCNN.v1" +pretrained_vectors = null +width = 333 +depth = 4 +embed_size = 5555 +window_size = 1 +maxout_pieces = 7 +subword_features = false +""" + + +@registry.architectures.register("my_test_parser") +def my_parser(): + tok2vec = build_Tok2Vec_model(width=321, embed_size=5432, pretrained_vectors=None, window_size=3, + maxout_pieces=4, subword_features=True, char_embed=True, nM=64, nC=8, + conv_depth=2, bilstm_depth=0) + parser = build_tb_parser_model(tok2vec=tok2vec, nr_feature_tokens=7, hidden_width=65, maxout_pieces=5) + return parser + + +def test_serialize_nlp(): + """ Create a custom nlp pipeline from config and ensure it serializes it correctly """ + nlp_config = Config().from_str(nlp_config_string) + nlp = util.load_model_from_config(nlp_config["nlp"]) + nlp.begin_training() + assert "tok2vec" in nlp.pipe_names + assert "tagger" in nlp.pipe_names + assert "parser" not in nlp.pipe_names + assert nlp.get_pipe("tagger").model.get_ref("tok2vec").get_dim("nO") == 342 + + with make_tempdir() as d: + nlp.to_disk(d) + nlp2 = spacy.load(d) + assert "tok2vec" in nlp2.pipe_names + assert "tagger" in nlp2.pipe_names + assert "parser" not in nlp2.pipe_names + assert nlp2.get_pipe("tagger").model.get_ref("tok2vec").get_dim("nO") == 342 + + +def test_serialize_custom_nlp(): + """ Create a custom nlp pipeline and ensure it serializes it correctly""" + nlp = English() + parser_cfg = dict() + parser_cfg["model"] = {'@architectures': "my_test_parser"} + parser = nlp.create_pipe("parser", parser_cfg) + nlp.add_pipe(parser) + nlp.begin_training() + + with make_tempdir() as d: + nlp.to_disk(d) + nlp2 = spacy.load(d) + model = nlp2.get_pipe("parser").model + tok2vec = model.get_ref("tok2vec") + upper = model.upper + + # check that we have the correct settings, not the default ones + assert tok2vec.get_dim("nO") == 321 + assert upper.get_dim("nI") == 65 + + +def test_serialize_parser(): + """ Create a non-default parser config to check nlp serializes it correctly """ + nlp = English() + model_config = Config().from_str(parser_config_string) + parser = nlp.create_pipe("parser", config=model_config) + parser.add_label("nsubj") + nlp.add_pipe(parser) + nlp.begin_training() + + with make_tempdir() as d: + nlp.to_disk(d) + nlp2 = spacy.load(d) + model = nlp2.get_pipe("parser").model + tok2vec = model.get_ref("tok2vec") + upper = model.upper + + # check that we have the correct settings, not the default ones + assert upper.get_dim("nI") == 66 + assert tok2vec.get_dim("nO") == 333 diff --git a/spacy/tests/serialize/test_serialize_language.py b/spacy/tests/serialize/test_serialize_language.py index 4089a0d07..0e3b7c59f 100644 --- a/spacy/tests/serialize/test_serialize_language.py +++ b/spacy/tests/serialize/test_serialize_language.py @@ -1,5 +1,6 @@ import pytest import re + from spacy.language import Language from spacy.tokenizer import Tokenizer @@ -56,7 +57,7 @@ def test_serialize_language_exclude(meta_data): nlp = Language(meta=meta_data) assert nlp.meta["name"] == name new_nlp = Language().from_bytes(nlp.to_bytes()) - assert nlp.meta["name"] == name + assert new_nlp.meta["name"] == name new_nlp = Language().from_bytes(nlp.to_bytes(), exclude=["meta"]) assert not new_nlp.meta["name"] == name new_nlp = Language().from_bytes(nlp.to_bytes(exclude=["meta"])) diff --git a/spacy/tests/serialize/test_serialize_pipeline.py b/spacy/tests/serialize/test_serialize_pipeline.py index 0ad9bc4d4..fe14fba10 100644 --- a/spacy/tests/serialize/test_serialize_pipeline.py +++ b/spacy/tests/serialize/test_serialize_pipeline.py @@ -1,6 +1,7 @@ import pytest from spacy.pipeline import Tagger, DependencyParser, EntityRecognizer from spacy.pipeline import Tensorizer, TextCategorizer, SentenceRecognizer +from spacy.ml.models.defaults import default_parser, default_tensorizer, default_tagger, default_textcat, default_sentrec from ..util import make_tempdir @@ -10,58 +11,58 @@ test_parsers = [DependencyParser, EntityRecognizer] @pytest.fixture def parser(en_vocab): - parser = DependencyParser(en_vocab) + parser = DependencyParser(en_vocab, default_parser()) parser.add_label("nsubj") - parser.model, cfg = parser.Model(parser.moves.n_moves) - parser.cfg.update(cfg) return parser @pytest.fixture def blank_parser(en_vocab): - parser = DependencyParser(en_vocab) + parser = DependencyParser(en_vocab, default_parser()) return parser @pytest.fixture def taggers(en_vocab): - tagger1 = Tagger(en_vocab) - tagger2 = Tagger(en_vocab) - tagger1.model = tagger1.Model(8) - tagger2.model = tagger1.model - return (tagger1, tagger2) + model = default_tagger() + tagger1 = Tagger(en_vocab, model) + tagger2 = Tagger(en_vocab, model) + return tagger1, tagger2 @pytest.mark.parametrize("Parser", test_parsers) def test_serialize_parser_roundtrip_bytes(en_vocab, Parser): - parser = Parser(en_vocab) - parser.model, _ = parser.Model(10) - new_parser = Parser(en_vocab) - new_parser.model, _ = new_parser.Model(10) + parser = Parser(en_vocab, default_parser()) + new_parser = Parser(en_vocab, default_parser()) new_parser = new_parser.from_bytes(parser.to_bytes(exclude=["vocab"])) - assert new_parser.to_bytes(exclude=["vocab"]) == parser.to_bytes(exclude=["vocab"]) + bytes_2 = new_parser.to_bytes(exclude=["vocab"]) + bytes_3 = parser.to_bytes(exclude=["vocab"]) + assert len(bytes_2) == len(bytes_3) + assert bytes_2 == bytes_3 @pytest.mark.parametrize("Parser", test_parsers) def test_serialize_parser_roundtrip_disk(en_vocab, Parser): - parser = Parser(en_vocab) - parser.model, _ = parser.Model(0) + parser = Parser(en_vocab, default_parser()) with make_tempdir() as d: file_path = d / "parser" parser.to_disk(file_path) - parser_d = Parser(en_vocab) - parser_d.model, _ = parser_d.Model(0) + parser_d = Parser(en_vocab, default_parser()) parser_d = parser_d.from_disk(file_path) parser_bytes = parser.to_bytes(exclude=["model", "vocab"]) parser_d_bytes = parser_d.to_bytes(exclude=["model", "vocab"]) + assert len(parser_bytes) == len(parser_d_bytes) assert parser_bytes == parser_d_bytes def test_to_from_bytes(parser, blank_parser): assert parser.model is not True - assert blank_parser.model is True + assert blank_parser.model is not True assert blank_parser.moves.n_moves != parser.moves.n_moves bytes_data = parser.to_bytes(exclude=["vocab"]) + + # the blank parser needs to be resized before we can call from_bytes + blank_parser.model.resize_output(parser.moves.n_moves) blank_parser.from_bytes(bytes_data) assert blank_parser.model is not True assert blank_parser.moves.n_moves == parser.moves.n_moves @@ -75,8 +76,10 @@ def test_serialize_tagger_roundtrip_bytes(en_vocab, taggers): tagger1_b = tagger1.to_bytes() tagger1 = tagger1.from_bytes(tagger1_b) assert tagger1.to_bytes() == tagger1_b - new_tagger1 = Tagger(en_vocab).from_bytes(tagger1_b) - assert new_tagger1.to_bytes() == tagger1_b + new_tagger1 = Tagger(en_vocab, default_tagger()).from_bytes(tagger1_b) + new_tagger1_b = new_tagger1.to_bytes() + assert len(new_tagger1_b) == len(tagger1_b) + assert new_tagger1_b == tagger1_b def test_serialize_tagger_roundtrip_disk(en_vocab, taggers): @@ -86,26 +89,24 @@ def test_serialize_tagger_roundtrip_disk(en_vocab, taggers): file_path2 = d / "tagger2" tagger1.to_disk(file_path1) tagger2.to_disk(file_path2) - tagger1_d = Tagger(en_vocab).from_disk(file_path1) - tagger2_d = Tagger(en_vocab).from_disk(file_path2) + tagger1_d = Tagger(en_vocab, default_tagger()).from_disk(file_path1) + tagger2_d = Tagger(en_vocab, default_tagger()).from_disk(file_path2) assert tagger1_d.to_bytes() == tagger2_d.to_bytes() def test_serialize_tensorizer_roundtrip_bytes(en_vocab): - tensorizer = Tensorizer(en_vocab) - tensorizer.model = tensorizer.Model() + tensorizer = Tensorizer(en_vocab, default_tensorizer()) tensorizer_b = tensorizer.to_bytes(exclude=["vocab"]) - new_tensorizer = Tensorizer(en_vocab).from_bytes(tensorizer_b) + new_tensorizer = Tensorizer(en_vocab, default_tensorizer()).from_bytes(tensorizer_b) assert new_tensorizer.to_bytes(exclude=["vocab"]) == tensorizer_b def test_serialize_tensorizer_roundtrip_disk(en_vocab): - tensorizer = Tensorizer(en_vocab) - tensorizer.model = tensorizer.Model() + tensorizer = Tensorizer(en_vocab, default_tensorizer()) with make_tempdir() as d: file_path = d / "tensorizer" tensorizer.to_disk(file_path) - tensorizer_d = Tensorizer(en_vocab).from_disk(file_path) + tensorizer_d = Tensorizer(en_vocab, default_tensorizer()).from_disk(file_path) assert tensorizer.to_bytes(exclude=["vocab"]) == tensorizer_d.to_bytes( exclude=["vocab"] ) @@ -113,19 +114,17 @@ def test_serialize_tensorizer_roundtrip_disk(en_vocab): def test_serialize_textcat_empty(en_vocab): # See issue #1105 - textcat = TextCategorizer(en_vocab, labels=["ENTITY", "ACTION", "MODIFIER"]) + textcat = TextCategorizer(en_vocab, default_textcat(), labels=["ENTITY", "ACTION", "MODIFIER"]) textcat.to_bytes(exclude=["vocab"]) @pytest.mark.parametrize("Parser", test_parsers) def test_serialize_pipe_exclude(en_vocab, Parser): def get_new_parser(): - new_parser = Parser(en_vocab) - new_parser.model, _ = new_parser.Model(0) + new_parser = Parser(en_vocab, default_parser()) return new_parser - parser = Parser(en_vocab) - parser.model, _ = parser.Model(0) + parser = Parser(en_vocab, default_parser()) parser.cfg["foo"] = "bar" new_parser = get_new_parser().from_bytes(parser.to_bytes(exclude=["vocab"])) assert "foo" in new_parser.cfg @@ -144,7 +143,7 @@ def test_serialize_pipe_exclude(en_vocab, Parser): def test_serialize_sentencerecognizer(en_vocab): - sr = SentenceRecognizer(en_vocab) + sr = SentenceRecognizer(en_vocab, default_sentrec()) sr_b = sr.to_bytes() - sr_d = SentenceRecognizer(en_vocab).from_bytes(sr_b) + sr_d = SentenceRecognizer(en_vocab, default_sentrec()).from_bytes(sr_b) assert sr.to_bytes() == sr_d.to_bytes() diff --git a/spacy/tests/test_tok2vec.py b/spacy/tests/test_tok2vec.py index 2d10d79d4..310103d10 100644 --- a/spacy/tests/test_tok2vec.py +++ b/spacy/tests/test_tok2vec.py @@ -1,6 +1,6 @@ import pytest -from spacy.ml.component_models import Tok2Vec +from spacy.ml.models.tok2vec import build_Tok2Vec_model from spacy.vocab import Vocab from spacy.tokens import Doc @@ -25,7 +25,8 @@ def test_empty_doc(): embed_size = 2000 vocab = Vocab() doc = Doc(vocab, words=[]) - tok2vec = Tok2Vec(width, embed_size) + # TODO: fix tok2vec arguments + tok2vec = build_Tok2Vec_model(width, embed_size) vectors, backprop = tok2vec.begin_update([doc]) assert len(vectors) == 1 assert vectors[0].shape == (0, width) @@ -36,7 +37,19 @@ def test_empty_doc(): ) def test_tok2vec_batch_sizes(batch_size, width, embed_size): batch = get_batch(batch_size) - tok2vec = Tok2Vec(width, embed_size) + tok2vec = build_Tok2Vec_model( + width, + embed_size, + pretrained_vectors=None, + conv_depth=4, + bilstm_depth=0, + window_size=1, + maxout_pieces=3, + subword_features=True, + char_embed=False, + nM=64, + nC=8, + ) tok2vec.initialize() vectors, backprop = tok2vec.begin_update(batch) assert len(vectors) == len(batch) @@ -44,19 +57,24 @@ def test_tok2vec_batch_sizes(batch_size, width, embed_size): assert doc_vec.shape == (len(doc), width) +# fmt: off @pytest.mark.parametrize( "tok2vec_config", [ - {"width": 8, "embed_size": 100, "char_embed": False}, - {"width": 8, "embed_size": 100, "char_embed": True}, - {"width": 8, "embed_size": 100, "conv_depth": 6}, - {"width": 8, "embed_size": 100, "conv_depth": 6}, - {"width": 8, "embed_size": 100, "subword_features": False}, + {"width": 8, "embed_size": 100, "char_embed": False, "nM": 64, "nC": 8, "pretrained_vectors": None, "window_size": 1, "conv_depth": 2, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": True}, + {"width": 8, "embed_size": 100, "char_embed": True, "nM": 64, "nC": 8, "pretrained_vectors": None, "window_size": 1, "conv_depth": 2, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": True}, + {"width": 8, "embed_size": 100, "char_embed": False, "nM": 64, "nC": 8, "pretrained_vectors": None, "window_size": 1, "conv_depth": 6, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": True}, + {"width": 8, "embed_size": 100, "char_embed": False, "nM": 64, "nC": 8, "pretrained_vectors": None, "window_size": 1, "conv_depth": 6, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": True}, + {"width": 8, "embed_size": 100, "char_embed": False, "nM": 64, "nC": 8, "pretrained_vectors": None, "window_size": 1, "conv_depth": 2, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": False}, + {"width": 8, "embed_size": 100, "char_embed": False, "nM": 64, "nC": 8, "pretrained_vectors": None, "window_size": 3, "conv_depth": 2, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": False}, + {"width": 8, "embed_size": 100, "char_embed": True, "nM": 81, "nC": 8, "pretrained_vectors": None, "window_size": 3, "conv_depth": 2, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": False}, + {"width": 8, "embed_size": 100, "char_embed": True, "nM": 81, "nC": 9, "pretrained_vectors": None, "window_size": 3, "conv_depth": 2, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": False}, ], ) +# fmt: on def test_tok2vec_configs(tok2vec_config): docs = get_batch(3) - tok2vec = Tok2Vec(**tok2vec_config) + tok2vec = build_Tok2Vec_model(**tok2vec_config) tok2vec.initialize() vectors, backprop = tok2vec.begin_update(docs) assert len(vectors) == len(docs) diff --git a/spacy/util.py b/spacy/util.py index 465b9645e..286a6574c 100644 --- a/spacy/util.py +++ b/spacy/util.py @@ -6,8 +6,7 @@ from pathlib import Path import random from typing import List import thinc -import thinc.config -from thinc.api import NumpyOps, get_current_ops, Adam, require_gpu +from thinc.api import NumpyOps, get_current_ops, Adam, require_gpu, Config import functools import itertools import numpy.random @@ -146,6 +145,10 @@ def load_model_from_path(model_path, meta=False, **overrides): pipeline from meta.json and then calls from_disk() with path.""" if not meta: meta = get_model_meta(model_path) + nlp_config = get_model_config(model_path) + if nlp_config.get("nlp", None): + return load_model_from_config(nlp_config["nlp"]) + # Support language factories registered via entry points (e.g. custom # language subclass) while keeping top-level language identifier "lang" lang = meta.get("lang_factory", meta["lang"]) @@ -162,11 +165,30 @@ def load_model_from_path(model_path, meta=False, **overrides): if name not in disable: config = meta.get("pipeline_args", {}).get(name, {}) factory = factories.get(name, name) + if nlp_config.get(name, None): + model_config = nlp_config[name]["model"] + config["model"] = model_config component = nlp.create_pipe(factory, config=config) nlp.add_pipe(component, name=name) return nlp.from_disk(model_path, exclude=disable) +def load_model_from_config(nlp_config): + if "name" in nlp_config: + nlp = load_model(**nlp_config) + elif "lang" in nlp_config: + lang_class = get_lang_class(nlp_config["lang"]) + nlp = lang_class() + else: + raise ValueError(Errors.E993) + if "pipeline" in nlp_config: + for name, component_cfg in nlp_config["pipeline"].items(): + factory = component_cfg.pop("factory") + component = nlp.create_pipe(factory, config=component_cfg) + nlp.add_pipe(component, name=name) + return nlp + + def load_model_from_init_py(init_file, **overrides): """Helper function to use in the `load()` method of a model package's __init__.py. @@ -184,7 +206,7 @@ def load_model_from_init_py(init_file, **overrides): return load_model_from_path(data_path, meta, **overrides) -def load_from_config(path, create_objects=False): +def load_config(path, create_objects=False): """Load a Thinc-formatted config file, optionally filling in objects where the config references registry entries. See "Thinc config files" for details. @@ -212,7 +234,7 @@ def get_model_meta(path): raise IOError(Errors.E052.format(path=model_path)) meta_path = model_path / "meta.json" if not meta_path.is_file(): - raise IOError(Errors.E053.format(path=meta_path)) + raise IOError(Errors.E053.format(path=meta_path, name="meta.json")) meta = srsly.read_json(meta_path) for setting in ["lang", "name", "version"]: if setting not in meta or not meta[setting]: @@ -220,6 +242,23 @@ def get_model_meta(path): return meta +def get_model_config(path): + """Get the model's config from a directory path. + + path (unicode or Path): Path to model directory. + RETURNS (Config): The model's config data. + """ + model_path = ensure_path(path) + if not model_path.exists(): + raise IOError(Errors.E052.format(path=model_path)) + config_path = model_path / "config.cfg" + # model directories are allowed not to have config files ? + if not config_path.is_file(): + return Config({}) + # raise IOError(Errors.E053.format(path=config_path, name="config.cfg")) + return Config().from_disk(config_path) + + def is_package(name): """Check if string maps to a package installed via pip.