Mirror of https://github.com/explosion/spaCy.git, synced 2025-11-04 01:48:04 +03:00
	* Add load_from_config function * Add train_from_config script * Merge configs and expose via spacy.config * Fix script * Suggest create_evaluation_callback * Hard-code for NER * Fix errors * Register command * Add TODO * Update train-from-config todos * Fix imports * Allow delayed setting of parser model nr_class * Get train-from-config working * Tidy up and fix scores and printing * Hide traceback if cancelled * Fix weighted score formatting * Fix score formatting * Make output_path optional * Add Tok2Vec component * Tidy up and add tok2vec_tensors * Add option to copy docs in nlp.update * Copy docs in nlp.update * Adjust nlp.update() for set_annotations * Don't shuffle pipes in nlp.update, decruft * Support set_annotations arg in component update * Support set_annotations in parser update * Add get_gradients method * Add get_gradients to parser * Update errors.py * Fix problems caused by merge * Add _link_components method in nlp * Add concept of 'listeners' and ControlledModel * Support optional attributes arg in ControlledModel * Try having tok2vec component in pipeline * Fix tok2vec component * Fix config * Fix tok2vec * Update for Example * Update for Example * Update config * Add eg2doc util * Update and add schemas/types * Update schemas * Fix nlp.update * Fix tagger * Remove hacks from train-from-config * Remove hard-coded config str * Calculate loss in tok2vec component * Tidy up and use function signatures instead of models * Support union types for registry models * Minor cleaning in Language.update * Make ControlledModel specifically Tok2VecListener * Fix train_from_config * Fix tok2vec * Tidy up * Add function for bilstm tok2vec * Fix type * Fix syntax * Fix pytorch optimizer * Add example configs * Update for thinc describe changes * Update for Thinc changes * Update for dropout/sgd changes * Update for dropout/sgd changes * Unhack gradient update * Work on refactoring _ml * Remove _ml.py module * WIP upgrade cli scripts for thinc * Move some _ml stuff to util * Import link_vectors from util * Update train_from_config * Import from util * Import from util * Temporarily add ml.component_models module * Move ml methods * Move typedefs * Update load vectors * Update gitignore * Move imports * Add PrecomputableAffine * Fix imports * Fix imports * Fix imports * Fix missing imports * Update CLI scripts * Update spacy.language * Add stubs for building the models * Update model definition * Update create_default_optimizer * Fix import * Fix comment * Update imports in tests * Update imports in spacy.cli * Fix import * fix obsolete thinc imports * update srsly pin * from thinc to ml_datasets for example data such as imdb * update ml_datasets pin * using STATE.vectors * small fix * fix Sentencizer.pipe * black formatting * rename Affine to Linear as in thinc * set validate explicitely to True * rename with_square_sequences to with_list2padded * rename with_flatten to with_list2array * chaining layernorm * small fixes * revert Optimizer import * build_nel_encoder with new thinc style * fixes using model's get and set methods * Tok2Vec in component models, various fixes * fix up legacy tok2vec code * add model initialize calls * add in build_tagger_model * small fixes * setting model dims * fixes for ParserModel * various small fixes * initialize thinc Models * fixes * consistent naming of window_size * fixes, removing set_dropout * work around Iterable issue * remove legacy tok2vec * util fix * fix forward function of tok2vec listener * more fixes * trying to fix PrecomputableAffine 
(not succesful yet) * alloc instead of allocate * add morphologizer * rename residual * rename fixes * Fix predict function * Update parser and parser model * fixing few more tests * Fix precomputable affine * Update component model * Update parser model * Move backprop padding to own function, for test * Update test * Fix p. affine * Update NEL * build_bow_text_classifier and extract_ngrams * Fix parser init * Fix test add label * add build_simple_cnn_text_classifier * Fix parser init * Set gpu off by default in example * Fix tok2vec listener * Fix parser model * Small fixes * small fix for PyTorchLSTM parameters * revert my_compounding hack (iterable fixed now) * fix biLSTM * Fix uniqued * PyTorchRNNWrapper fix * small fixes * use helper function to calculate cosine loss * small fixes for build_simple_cnn_text_classifier * putting dropout default at 0.0 to ensure the layer gets built * using thinc util's set_dropout_rate * moving layer normalization inside of maxout definition to optimize dropout * temp debugging in NEL * fixed NEL model by using init defaults ! * fixing after set_dropout_rate refactor * proper fix * fix test_update_doc after refactoring optimizers in thinc * Add CharacterEmbed layer * Construct tagger Model * Add missing import * Remove unused stuff * Work on textcat * fix test (again :)) after optimizer refactor * fixes to allow reading Tagger from_disk without overwriting dimensions * don't build the tok2vec prematuraly * fix CharachterEmbed init * CharacterEmbed fixes * Fix CharacterEmbed architecture * fix imports * renames from latest thinc update * one more rename * add initialize calls where appropriate * fix parser initialization * Update Thinc version * Fix errors, auto-format and tidy up imports * Fix validation * fix if bias is cupy array * revert for now * ensure it's a numpy array before running bp in ParserStepModel * no reason to call require_gpu twice * use CupyOps.to_numpy instead of cupy directly * fix initialize of ParserModel * remove unnecessary import * fixes for CosineDistance * fix device renaming * use refactored loss functions (Thinc PR 251) * overfitting test for tagger * experimental settings for the tagger: avoid zero-init and subword normalization * clean up tagger overfitting test * use previous default value for nP * remove toy config * bringing layernorm back (had a bug - fixed in thinc) * revert setting nP explicitly * remove setting default in constructor * restore values as they used to be * add overfitting test for NER * add overfitting test for dep parser * add overfitting test for textcat * fixing init for linear (previously affine) * larger eps window for textcat * ensure doc is not None * Require newer thinc * Make float check vaguer * Slop the textcat overfit test more * Fix textcat test * Fix exclusive classes for textcat * fix after renaming of alloc methods * fixing renames and mandatory arguments (staticvectors WIP) * upgrade to thinc==8.0.0.dev3 * refer to vocab.vectors directly instead of its name * rename alpha to learn_rate * adding hashembed and staticvectors dropout * upgrade to thinc 8.0.0.dev4 * add name back to avoid warning W020 * thinc dev4 * update srsly * using thinc 8.0.0a0 ! Co-authored-by: Matthew Honnibal <honnibal+gh@gmail.com> Co-authored-by: Ines Montani <ines@ines.io>
622 lines · 27 KiB · Python
import os
import tqdm
from pathlib import Path
from thinc.backends import use_ops
from timeit import default_timer as timer
import shutil
import srsly
from wasabi import msg
import contextlib
import random

from ..util import create_default_optimizer
from ..attrs import PROB, IS_OOV, CLUSTER, LANG
from ..gold import GoldCorpus
from .. import util
from .. import about


def train(
    # fmt: off
    lang: ("Model language", "positional", None, str),
    output_path: ("Output directory to store model in", "positional", None, Path),
    train_path: ("Location of JSON-formatted training data", "positional", None, Path),
    dev_path: ("Location of JSON-formatted development data", "positional", None, Path),
    raw_text: ("Path to jsonl file with unlabelled text documents.", "option", "rt", Path) = None,
    base_model: ("Name of model to update (optional)", "option", "b", str) = None,
    pipeline: ("Comma-separated names of pipeline components", "option", "p", str) = "tagger,parser,ner",
    vectors: ("Model to load vectors from", "option", "v", str) = None,
    n_iter: ("Number of iterations", "option", "n", int) = 30,
    n_early_stopping: ("Maximum number of training epochs without dev accuracy improvement", "option", "ne", int) = None,
    n_examples: ("Number of examples", "option", "ns", int) = 0,
    use_gpu: ("Use GPU", "option", "g", int) = -1,
    version: ("Model version", "option", "V", str) = "0.0.0",
    meta_path: ("Optional path to meta.json to use as base.", "option", "m", Path) = None,
    init_tok2vec: ("Path to pretrained weights for the token-to-vector parts of the models. See 'spacy pretrain'. Experimental.", "option", "t2v", Path) = None,
    parser_multitasks: ("Side objectives for parser CNN, e.g. 'dep' or 'dep,tag'", "option", "pt", str) = "",
    entity_multitasks: ("Side objectives for NER CNN, e.g. 'dep' or 'dep,tag'", "option", "et", str) = "",
    noise_level: ("Amount of corruption for data augmentation", "option", "nl", float) = 0.0,
    orth_variant_level: ("Amount of orthography variation for data augmentation", "option", "ovl", float) = 0.0,
    eval_beam_widths: ("Beam widths to evaluate, e.g. 4,8", "option", "bw", str) = "",
    gold_preproc: ("Use gold preprocessing", "flag", "G", bool) = False,
    learn_tokens: ("Make parser learn gold-standard tokenization", "flag", "T", bool) = False,
    textcat_multilabel: ("Textcat classes aren't mutually exclusive (multilabel)", "flag", "TML", bool) = False,
    textcat_arch: ("Textcat model architecture", "option", "ta", str) = "bow",
    textcat_positive_label: ("Textcat positive label for binary classes with two labels", "option", "tpl", str) = None,
    tag_map_path: ("Location of JSON-formatted tag map", "option", "tm", Path) = None,
    verbose: ("Display more information for debug", "flag", "VV", bool) = False,
    debug: ("Run data diagnostics before training", "flag", "D", bool) = False,
    # fmt: on
):
    """
    Train or update a spaCy model. Requires data to be formatted in spaCy's
    JSON format. To convert data from other formats, use the `spacy convert`
    command.
    """
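    # A typical invocation looks roughly like this (a sketch; the options map
    # onto the keyword arguments above, and exact flag spellings come from the
    # CLI wrapper):
    #
    #   python -m spacy train en /output /data/train.json /data/dev.json \
    #       --pipeline tagger,parser,ner --n-iter 30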
    util.fix_random_seed()
    util.set_env_log(verbose)

    # Make sure all files and paths exist if they are needed
    train_path = util.ensure_path(train_path)
    dev_path = util.ensure_path(dev_path)
    meta_path = util.ensure_path(meta_path)
    output_path = util.ensure_path(output_path)
    if raw_text is not None:
        raw_text = list(srsly.read_jsonl(raw_text))
    if not train_path or not train_path.exists():
        msg.fail("Training data not found", train_path, exits=1)
    if not dev_path or not dev_path.exists():
        msg.fail("Development data not found", dev_path, exits=1)
    if meta_path is not None and not meta_path.exists():
        msg.fail("Can't find model meta.json", meta_path, exits=1)
    meta = srsly.read_json(meta_path) if meta_path else {}
    if output_path.exists() and [p for p in output_path.iterdir() if p.is_dir()]:
        msg.warn(
            "Output directory is not empty",
            "This can lead to unintended side effects when saving the model. "
            "Please use an empty directory or a different path instead. If "
            "the specified output path doesn't exist, the directory will be "
            "created for you.",
        )
    if not output_path.exists():
        output_path.mkdir()

    tag_map = {}
    if tag_map_path is not None:
        tag_map = srsly.read_json(tag_map_path)
    # Take dropout and batch size as generators of values -- dropout
    # starts high and decays sharply, to force the optimizer to explore.
    # Batch size starts small and grows, so that we make updates quickly
    # at the beginning of training.
    dropout_rates = util.decaying(
        util.env_opt("dropout_from", 0.2),
        util.env_opt("dropout_to", 0.2),
        util.env_opt("dropout_decay", 0.0),
    )
    batch_sizes = util.compounding(
        util.env_opt("batch_from", 100.0),
        util.env_opt("batch_to", 1000.0),
        util.env_opt("batch_compound", 1.001),
    )
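    # A rough sketch of how these schedule generators behave (illustrative,
    # not the exact util implementation):
    #
    #     def compounding(start, stop, compound):
    #         value = start
    #         while True:
    #             yield min(value, stop)
    #             value *= compound
    #
    # With the defaults above, batch sizes grow geometrically from 100 towards
    # 1000, while the dropout schedule stays constant at 0.2 (decay rate 0.0).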

    if not eval_beam_widths:
        eval_beam_widths = [1]
    else:
        eval_beam_widths = [int(bw) for bw in eval_beam_widths.split(",")]
        if 1 not in eval_beam_widths:
            eval_beam_widths.append(1)
        eval_beam_widths.sort()
    has_beam_widths = eval_beam_widths != [1]
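    # For example, passing "4,8" evaluates beam widths [1, 4, 8]; width 1
    # (greedy decoding) is always included as the baseline.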

    # Set up the base model and pipeline. If a base model is specified, load
    # the model and make sure the pipeline matches the pipeline setting. If
    # training starts from a blank model, initialize the language class.
    pipeline = [p.strip() for p in pipeline.split(",")]
    msg.text(f"Training pipeline: {pipeline}")
    if base_model:
        msg.text(f"Starting with base model '{base_model}'")
        nlp = util.load_model(base_model)
        if nlp.lang != lang:
            msg.fail(
                f"Model language ('{nlp.lang}') doesn't match language "
                f"specified as `lang` argument ('{lang}') ",
                exits=1,
            )
        nlp.disable_pipes([p for p in nlp.pipe_names if p not in pipeline])
        for pipe in pipeline:
            if pipe not in nlp.pipe_names:
                if pipe == "parser":
                    pipe_cfg = {"learn_tokens": learn_tokens}
                elif pipe == "textcat":
                    pipe_cfg = {
                        "exclusive_classes": not textcat_multilabel,
                        "architecture": textcat_arch,
                        "positive_label": textcat_positive_label,
                    }
                else:
                    pipe_cfg = {}
                nlp.add_pipe(nlp.create_pipe(pipe, config=pipe_cfg))
            else:
                if pipe == "textcat":
                    textcat_cfg = nlp.get_pipe("textcat").cfg
                    base_cfg = {
                        "exclusive_classes": textcat_cfg["exclusive_classes"],
                        "architecture": textcat_cfg["architecture"],
                        "positive_label": textcat_cfg["positive_label"],
                    }
                    pipe_cfg = {
                        "exclusive_classes": not textcat_multilabel,
                        "architecture": textcat_arch,
                        "positive_label": textcat_positive_label,
                    }
                    if base_cfg != pipe_cfg:
                        msg.fail(
                            f"The base textcat model configuration does "
                            f"not match the provided training options. "
                            f"Existing cfg: {base_cfg}, provided cfg: {pipe_cfg}",
                            exits=1,
                        )
    else:
        msg.text(f"Starting with blank model '{lang}'")
        lang_cls = util.get_lang_class(lang)
        nlp = lang_cls()
        for pipe in pipeline:
            if pipe == "parser":
                pipe_cfg = {"learn_tokens": learn_tokens}
            elif pipe == "textcat":
                pipe_cfg = {
                    "exclusive_classes": not textcat_multilabel,
                    "architecture": textcat_arch,
                    "positive_label": textcat_positive_label,
                }
            else:
                pipe_cfg = {}
            nlp.add_pipe(nlp.create_pipe(pipe, config=pipe_cfg))

    # Update tag map with provided mapping
    nlp.vocab.morphology.tag_map.update(tag_map)

    if vectors:
        msg.text(f"Loading vectors from model '{vectors}'")
        _load_vectors(nlp, vectors)

    # Multitask objectives
    multitask_options = [("parser", parser_multitasks), ("ner", entity_multitasks)]
    for pipe_name, multitasks in multitask_options:
        if multitasks:
            if pipe_name not in pipeline:
                msg.fail(
                    f"Can't use multitask objective without '{pipe_name}' in "
                    f"the pipeline"
                )
            pipe = nlp.get_pipe(pipe_name)
            for objective in multitasks.split(","):
                pipe.add_multitask_objective(objective)

    # Prepare training corpus
    msg.text(f"Counting training words (limit={n_examples})")
    corpus = GoldCorpus(train_path, dev_path, limit=n_examples)
    n_train_words = corpus.count_train()

    if base_model:
        # Start with an existing model, use default optimizer
        optimizer = create_default_optimizer()
    else:
        # Start with a blank model, call begin_training
        optimizer = nlp.begin_training(lambda: corpus.train_examples, device=use_gpu)

    nlp._optimizer = None

    # Load in pretrained weights
    if init_tok2vec is not None:
        components = _load_pretrained_tok2vec(nlp, init_tok2vec)
        msg.text(f"Loaded pretrained tok2vec for: {components}")

    # Verify textcat config
    if "textcat" in pipeline:
        textcat_labels = nlp.get_pipe("textcat").cfg["labels"]
        if textcat_positive_label and textcat_positive_label not in textcat_labels:
            msg.fail(
                f"The textcat_positive_label (tpl) '{textcat_positive_label}' "
                f"does not match any label in the training data.",
                exits=1,
            )
        if textcat_positive_label and len(textcat_labels) != 2:
            msg.fail(
                f"A textcat_positive_label (tpl) '{textcat_positive_label}' was "
                "provided for training data that does not appear to be a "
                "binary classification problem with two labels.",
                exits=1,
            )
        train_data = corpus.train_data(
            nlp,
            noise_level=noise_level,
            gold_preproc=gold_preproc,
            max_length=0,
            ignore_misaligned=True,
        )
        train_labels = set()
        if textcat_multilabel:
            multilabel_found = False
            for ex in train_data:
                train_labels.update(ex.gold.cats.keys())
                if list(ex.gold.cats.values()).count(1.0) != 1:
                    multilabel_found = True
            if not multilabel_found and not base_model:
                msg.warn(
                    "The textcat training instances look like they have "
                    "mutually-exclusive classes. Remove the flag "
                    "'--textcat-multilabel' to train a classifier with "
                    "mutually-exclusive classes."
                )
        if not textcat_multilabel:
            for ex in train_data:
                train_labels.update(ex.gold.cats.keys())
                if list(ex.gold.cats.values()).count(1.0) != 1 and not base_model:
                    msg.warn(
                        "Some textcat training instances do not have exactly "
                        "one positive label. Modifying training options to "
                        "include the flag '--textcat-multilabel' for classes "
                        "that are not mutually exclusive."
                    )
                    nlp.get_pipe("textcat").cfg["exclusive_classes"] = False
                    textcat_multilabel = True
                    break
        if base_model and set(textcat_labels) != train_labels:
            msg.fail(
                f"Cannot extend textcat model using data with different "
                f"labels. Base model labels: {textcat_labels}, training data "
                f"labels: {list(train_labels)}",
                exits=1,
            )
        if textcat_multilabel:
            msg.text(
                f"Textcat evaluation score: ROC AUC score macro-averaged across "
                f"the labels '{', '.join(textcat_labels)}'"
            )
        elif textcat_positive_label and len(textcat_labels) == 2:
            msg.text(
                f"Textcat evaluation score: F1-score for the "
                f"label '{textcat_positive_label}'"
            )
        elif len(textcat_labels) > 1:
            if len(textcat_labels) == 2:
                msg.warn(
                    "If the textcat component is a binary classifier with "
                    "exclusive classes, provide '--textcat_positive_label' for "
                    "an evaluation on the positive class."
                )
            msg.text(
                f"Textcat evaluation score: F1-score macro-averaged across "
                f"the labels '{', '.join(textcat_labels)}'"
            )
        else:
            msg.fail(
                "Unsupported textcat configuration. Use `spacy debug-data` "
                "for more information."
            )

    # fmt: off
    row_head, output_stats = _configure_training_output(pipeline, use_gpu, has_beam_widths)
    row_widths = [len(w) for w in row_head]
    row_settings = {"widths": row_widths, "aligns": tuple(["r" for i in row_head]), "spacing": 2}
    # fmt: on
    print("")
    msg.row(row_head, **row_settings)
    msg.row(["-" * width for width in row_settings["widths"]], **row_settings)
    try:
        iter_since_best = 0
        best_score = 0.0
        for i in range(n_iter):
            train_data = corpus.train_dataset(
                nlp,
                noise_level=noise_level,
                orth_variant_level=orth_variant_level,
                gold_preproc=gold_preproc,
                max_length=0,
                ignore_misaligned=True,
            )
            if raw_text:
                random.shuffle(raw_text)
                raw_batches = util.minibatch(
                    (nlp.make_doc(rt["text"]) for rt in raw_text), size=8
                )
            words_seen = 0
            with tqdm.tqdm(total=n_train_words, leave=False) as pbar:
                losses = {}
                for batch in util.minibatch_by_words(train_data, size=batch_sizes):
                    if not batch:
                        continue
                    nlp.update(
                        batch,
                        sgd=optimizer,
                        drop=next(dropout_rates),
                        losses=losses,
                    )
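                    # `losses` accumulates a running loss per trained component,
                    # e.g. something like {"tagger": 12.3, "parser": 45.6,
                    # "ner": 7.8} (illustrative values); it feeds the progress
                    # row printed at the end of the epoch.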
                    if raw_text:
                        # If raw text is available, perform 'rehearsal' updates,
                        # which use unlabelled data to reduce overfitting.
                        raw_batch = list(next(raw_batches))
                        nlp.rehearse(raw_batch, sgd=optimizer, losses=losses)
                    docs = [ex.doc for ex in batch]
                    if not int(os.environ.get("LOG_FRIENDLY", 0)):
                        pbar.update(sum(len(doc) for doc in docs))
                    words_seen += sum(len(doc) for doc in docs)
            with nlp.use_params(optimizer.averages):
                util.set_env_log(False)
                epoch_model_path = output_path / f"model{i}"
                nlp.to_disk(epoch_model_path)
                nlp_loaded = util.load_model_from_path(epoch_model_path)
                for beam_width in eval_beam_widths:
                    for name, component in nlp_loaded.pipeline:
                        if hasattr(component, "cfg"):
                            component.cfg["beam_width"] = beam_width
                    dev_dataset = list(
                        corpus.dev_dataset(
                            nlp_loaded,
                            gold_preproc=gold_preproc,
                            ignore_misaligned=True,
                        )
                    )
                    nwords = sum(len(ex.doc) for ex in dev_dataset)
                    start_time = timer()
                    scorer = nlp_loaded.evaluate(dev_dataset, verbose=verbose)
                    end_time = timer()
                    if use_gpu < 0:
                        gpu_wps = None
                        cpu_wps = nwords / (end_time - start_time)
                    else:
                        gpu_wps = nwords / (end_time - start_time)
                        with use_ops("numpy"):
                            nlp_loaded = util.load_model_from_path(epoch_model_path)
                            for name, component in nlp_loaded.pipeline:
                                if hasattr(component, "cfg"):
                                    component.cfg["beam_width"] = beam_width
                            dev_dataset = list(
                                corpus.dev_dataset(
                                    nlp_loaded,
                                    gold_preproc=gold_preproc,
                                    ignore_misaligned=True,
                                )
                            )
                            start_time = timer()
                            scorer = nlp_loaded.evaluate(dev_dataset, verbose=verbose)
                            end_time = timer()
                            cpu_wps = nwords / (end_time - start_time)
                    acc_loc = output_path / f"model{i}" / "accuracy.json"
                    srsly.write_json(acc_loc, scorer.scores)

                    # Update model meta.json
                    meta["lang"] = nlp.lang
                    meta["pipeline"] = nlp.pipe_names
                    meta["spacy_version"] = f">={about.__version__}"
                    if beam_width == 1:
                        meta["speed"] = {
                            "nwords": nwords,
                            "cpu": cpu_wps,
                            "gpu": gpu_wps,
                        }
                        meta["accuracy"] = scorer.scores
                    else:
                        meta.setdefault("beam_accuracy", {})
                        meta.setdefault("beam_speed", {})
                        meta["beam_accuracy"][beam_width] = scorer.scores
                        meta["beam_speed"][beam_width] = {
                            "nwords": nwords,
                            "cpu": cpu_wps,
                            "gpu": gpu_wps,
                        }
                    meta["vectors"] = {
                        "width": nlp.vocab.vectors_length,
                        "vectors": len(nlp.vocab.vectors),
                        "keys": nlp.vocab.vectors.n_keys,
                        "name": nlp.vocab.vectors.name,
                    }
                    meta.setdefault("name", f"model{i}")
                    meta.setdefault("version", version)
                    meta["labels"] = nlp.meta["labels"]
                    meta_loc = output_path / f"model{i}" / "meta.json"
                    srsly.write_json(meta_loc, meta)
                    util.set_env_log(verbose)

                    progress = _get_progress(
                        i,
                        losses,
                        scorer.scores,
                        output_stats,
                        beam_width=beam_width if has_beam_widths else None,
                        cpu_wps=cpu_wps,
                        gpu_wps=gpu_wps,
                    )
                    if i == 0 and "textcat" in pipeline:
                        textcats_per_cat = scorer.scores.get("textcats_per_cat", {})
                        for cat, cat_score in textcats_per_cat.items():
                            if cat_score.get("roc_auc_score", 0) < 0:
                                msg.warn(
                                    f"Textcat ROC AUC score is undefined due to "
                                    f"only one value in label '{cat}'."
                                )
                    msg.row(progress, **row_settings)
                # Early stopping
                if n_early_stopping is not None:
                    current_score = _score_for_model(meta)
                    if current_score < best_score:
                        iter_since_best += 1
                    else:
                        iter_since_best = 0
                        best_score = current_score
                    if iter_since_best >= n_early_stopping:
                        msg.text(
                            f"Early stopping, best iteration is: {i - iter_since_best}"
                        )
                        msg.text(
                            f"Best score = {best_score}; Final iteration score = {current_score}"
                        )
                        break
    finally:
        with nlp.use_params(optimizer.averages):
            final_model_path = output_path / "model-final"
            nlp.to_disk(final_model_path)
        msg.good("Saved model to output directory", final_model_path)
        with msg.loading("Creating best model..."):
            best_model_path = _collate_best_model(meta, output_path, nlp.pipe_names)
        msg.good("Created best model", best_model_path)


def _score_for_model(meta):
    """Return the mean score across the tasks in the pipeline, used for early stopping."""
    mean_acc = list()
    pipes = meta["pipeline"]
    acc = meta["accuracy"]
    if "tagger" in pipes:
        mean_acc.append(acc["tags_acc"])
    if "parser" in pipes:
        mean_acc.append((acc["uas"] + acc["las"]) / 2)
    if "ner" in pipes:
        mean_acc.append((acc["ents_p"] + acc["ents_r"] + acc["ents_f"]) / 3)
    if "textcat" in pipes:
        mean_acc.append(acc["textcat_score"])
    if "sentrec" in pipes:
        mean_acc.append((acc["sent_p"] + acc["sent_r"] + acc["sent_f"]) / 3)
    return sum(mean_acc) / len(mean_acc)
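
# Illustrative example (hypothetical numbers): for
# meta = {"pipeline": ["tagger", "parser"],
#         "accuracy": {"tags_acc": 90.0, "uas": 80.0, "las": 70.0}}
# the score is (90.0 + (80.0 + 70.0) / 2) / 2 == 82.5.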


@contextlib.contextmanager
def _create_progress_bar(total):
    if int(os.environ.get("LOG_FRIENDLY", 0)):
        yield
    else:
        pbar = tqdm.tqdm(total=total, leave=False)
        yield pbar


def _load_vectors(nlp, vectors):
    util.load_model(vectors, vocab=nlp.vocab)
    for lex in nlp.vocab:
        values = {}
        for attr, func in nlp.vocab.lex_attr_getters.items():
            # These attrs are expected to be set by data. Others should
            # be set by calling the language functions.
            if attr not in (CLUSTER, PROB, IS_OOV, LANG):
                values[lex.vocab.strings[attr]] = func(lex.orth_)
        lex.set_attrs(**values)
        lex.is_oov = False


def _load_pretrained_tok2vec(nlp, loc):
    """Load pretrained weights for the 'token-to-vector' part of the component
    models, which is typically a CNN. See 'spacy pretrain'. Experimental.
    """
    with loc.open("rb") as file_:
        weights_data = file_.read()
    loaded = []
    for name, component in nlp.pipeline:
        if hasattr(component, "model") and hasattr(component.model, "tok2vec"):
            component.tok2vec.from_bytes(weights_data)
            loaded.append(name)
    return loaded


def _collate_best_model(meta, output_path, components):
    bests = {}
    for component in components:
        bests[component] = _find_best(output_path, component)
    best_dest = output_path / "model-best"
    shutil.copytree(str(output_path / "model-final"), str(best_dest))
    for component, best_component_src in bests.items():
        shutil.rmtree(str(best_dest / component))
        shutil.copytree(
            str(best_component_src / component), str(best_dest / component)
        )
        accs = srsly.read_json(best_component_src / "accuracy.json")
        for metric in _get_metrics(component):
            meta["accuracy"][metric] = accs[metric]
    srsly.write_json(best_dest / "meta.json", meta)
    return best_dest


def _find_best(experiment_dir, component):
    accuracies = []
    for epoch_model in experiment_dir.iterdir():
        if epoch_model.is_dir() and epoch_model.parts[-1] != "model-final":
            accs = srsly.read_json(epoch_model / "accuracy.json")
            scores = [accs.get(metric, 0.0) for metric in _get_metrics(component)]
            accuracies.append((scores, epoch_model))
    if accuracies:
        return max(accuracies)[1]
    else:
        return None
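
# Note: max(accuracies) compares the (scores, path) tuples lexicographically,
# so the epoch with the highest first metric wins and later metrics (then the
# path) only break ties -- e.g. a scores list of [0.85, 0.90] beats [0.84, 0.99].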


def _get_metrics(component):
    if component == "parser":
        return ("las", "uas", "token_acc", "sent_f")
    elif component == "tagger":
        return ("tags_acc",)
    elif component == "ner":
        return ("ents_f", "ents_p", "ents_r")
    elif component == "sentrec":
        return ("sent_f", "sent_p", "sent_r")
    return ("token_acc",)


def _configure_training_output(pipeline, use_gpu, has_beam_widths):
    row_head = ["Itn"]
    output_stats = []
    for pipe in pipeline:
        if pipe == "tagger":
            row_head.extend(["Tag Loss ", " Tag %  "])
            output_stats.extend(["tag_loss", "tags_acc"])
        elif pipe == "parser":
            row_head.extend(["Dep Loss ", " UAS  ", " LAS  ", "Sent P", "Sent R", "Sent F"])
            output_stats.extend(["dep_loss", "uas", "las", "sent_p", "sent_r", "sent_f"])
        elif pipe == "ner":
            row_head.extend(["NER Loss ", "NER P ", "NER R ", "NER F "])
            output_stats.extend(["ner_loss", "ents_p", "ents_r", "ents_f"])
        elif pipe == "textcat":
            row_head.extend(["Textcat Loss", "Textcat"])
            output_stats.extend(["textcat_loss", "textcat_score"])
        elif pipe == "sentrec":
            row_head.extend(["Sentrec Loss", "Sent P", "Sent R", "Sent F"])
            output_stats.extend(["sentrec_loss", "sent_p", "sent_r", "sent_f"])
    row_head.extend(["Token %", "CPU WPS"])
    output_stats.extend(["token_acc", "cpu_wps"])

    if use_gpu >= 0:
        row_head.extend(["GPU WPS"])
        output_stats.extend(["gpu_wps"])

    if has_beam_widths:
        row_head.insert(1, "Beam W.")
    # remove duplicates
    row_head_dict = {k: 1 for k in row_head}
    output_stats_dict = {k: 1 for k in output_stats}
    return row_head_dict.keys(), output_stats_dict.keys()
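
# For example, _configure_training_output(["tagger", "ner"], use_gpu=-1,
# has_beam_widths=False) produces headers along the lines of
# ["Itn", "Tag Loss ", " Tag %  ", "NER Loss ", "NER P ", "NER R ", "NER F ",
#  "Token %", "CPU WPS"] and the matching stat keys, with any duplicates dropped.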


def _get_progress(
    itn, losses, dev_scores, output_stats, beam_width=None, cpu_wps=0.0, gpu_wps=0.0
):
    scores = {}
    for stat in output_stats:
        scores[stat] = 0.0
    scores["dep_loss"] = losses.get("parser", 0.0)
    scores["ner_loss"] = losses.get("ner", 0.0)
    scores["tag_loss"] = losses.get("tagger", 0.0)
    scores["textcat_loss"] = losses.get("textcat", 0.0)
    scores["sentrec_loss"] = losses.get("sentrec", 0.0)
    scores["cpu_wps"] = cpu_wps
    scores["gpu_wps"] = gpu_wps or 0.0
    scores.update(dev_scores)
    formatted_scores = []
    for stat in output_stats:
        format_spec = "{:.3f}"
        if stat.endswith("_wps"):
            format_spec = "{:.0f}"
        formatted_scores.append(format_spec.format(scores[stat]))
    result = [itn + 1]
    result.extend(formatted_scores)
    if beam_width is not None:
        result.insert(1, beam_width)
    return result
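
# For example, with output_stats ("tag_loss", "tags_acc", "token_acc", "cpu_wps"),
# itn=0, losses={"tagger": 3.2}, dev_scores={"tags_acc": 91.5, "token_acc": 100.0}
# and cpu_wps=12000, the row comes out roughly as
# [1, "3.200", "91.500", "100.000", "12000"].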