From 97bcf2ae3a03cf97fd473062c1793e3d18ef2820 Mon Sep 17 00:00:00 2001 From: Adriane Boyd Date: Sat, 6 Mar 2021 08:42:14 +0100 Subject: [PATCH 01/10] Fix patience for identical scores (#7250) * Fix patience for identical scores Fix training patience so that the earliest best step is chosen for identical max scores. * Restore break, remove print * Explicitly define best_step for clarity --- spacy/training/loop.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/spacy/training/loop.py b/spacy/training/loop.py index dacd2dba4..55919014b 100644 --- a/spacy/training/loop.py +++ b/spacy/training/loop.py @@ -230,7 +230,10 @@ def train_while_improving( if is_best_checkpoint is not None: losses = {} # Stop if no improvement in `patience` updates (if specified) - best_score, best_step = max(results) + # Negate step value so that the earliest best step is chosen for the + # same score, i.e. (1.0, 100) is chosen over (1.0, 200) + best_result = max((r_score, -r_step) for r_score, r_step in results) + best_step = -best_result[1] if patience and (step - best_step) >= patience: break # Stop if we've exhausted our max steps (if specified) From cd70c3cb791b0e9a9e6323d8b79a267ba11284c9 Mon Sep 17 00:00:00 2001 From: Sofie Van Landeghem Date: Tue, 9 Mar 2021 04:01:13 +0100 Subject: [PATCH 02/10] Fixing pretrain (#7342) * initialize NLP with train corpus * add more pretraining tests * more tests * function to fetch tok2vec layer for pretraining * clarify parameter name * test different objectives * formatting * fix check for static vectors when using vectors objective * clarify docs * logger statement * fix init_tok2vec and proc.initialize order * test training after pretraining * add init_config tests for pretraining * pop pretraining block to avoid config validation errors * custom errors --- spacy/errors.py | 5 +- spacy/language.py | 12 +- spacy/ml/models/multi_task.py | 4 +- .../tests/serialize/test_serialize_config.py | 2 +- spacy/tests/test_cli.py | 10 +- spacy/tests/training/test_pretraining.py | 345 ++++++++++++++++++ spacy/training/initialize.py | 21 +- spacy/training/pretrain.py | 42 ++- website/docs/api/architectures.md | 7 +- 9 files changed, 413 insertions(+), 35 deletions(-) create mode 100644 spacy/tests/training/test_pretraining.py diff --git a/spacy/errors.py b/spacy/errors.py index 2ebc49e8c..4f61cf098 100644 --- a/spacy/errors.py +++ b/spacy/errors.py @@ -487,7 +487,10 @@ class Errors: E202 = ("Unsupported alignment mode '{mode}'. Supported modes: {modes}.") # New errors added in v3.x - + E874 = ("Could not initialize the tok2vec model from component " + "'{component}' and layer '{layer}'.") + E875 = ("To use the PretrainVectors objective, make sure that static vectors are loaded. " + "In the config, these are defined by the initialize.vectors setting.") E879 = ("Unexpected type for 'spans' data. Provide a dictionary mapping keys to " "a list of spans, with each span represented by a tuple (start_char, end_char). 
" "The tuple can be optionally extended with a label and a KB ID.") diff --git a/spacy/language.py b/spacy/language.py index 2a9b50bcc..80de94278 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -1222,10 +1222,6 @@ class Language: init_vocab( self, data=I["vocab_data"], lookups=I["lookups"], vectors=I["vectors"] ) - pretrain_cfg = config.get("pretraining") - if pretrain_cfg: - P = registry.resolve(pretrain_cfg, schema=ConfigSchemaPretrain) - init_tok2vec(self, P, I) if self.vocab.vectors.data.shape[1] >= 1: ops = get_current_ops() self.vocab.vectors.data = ops.asarray(self.vocab.vectors.data) @@ -1244,6 +1240,10 @@ class Language: proc.initialize, p_settings, section="components", name=name ) proc.initialize(get_examples, nlp=self, **p_settings) + pretrain_cfg = config.get("pretraining") + if pretrain_cfg: + P = registry.resolve(pretrain_cfg, schema=ConfigSchemaPretrain) + init_tok2vec(self, P, I) self._link_components() self._optimizer = sgd if sgd is not None: @@ -1592,6 +1592,7 @@ class Language: # using the nlp.config with all defaults. config = util.copy_config(config) orig_pipeline = config.pop("components", {}) + orig_pretraining = config.pop("pretraining", None) config["components"] = {} if auto_fill: filled = registry.fill(config, validate=validate, schema=ConfigSchema) @@ -1599,6 +1600,9 @@ class Language: filled = config filled["components"] = orig_pipeline config["components"] = orig_pipeline + if orig_pretraining is not None: + filled["pretraining"] = orig_pretraining + config["pretraining"] = orig_pretraining resolved_nlp = registry.resolve( filled["nlp"], validate=validate, schema=ConfigSchemaNlp ) diff --git a/spacy/ml/models/multi_task.py b/spacy/ml/models/multi_task.py index 8aa0f3c2b..cbfa59eea 100644 --- a/spacy/ml/models/multi_task.py +++ b/spacy/ml/models/multi_task.py @@ -21,6 +21,8 @@ def create_pretrain_vectors( maxout_pieces: int, hidden_size: int, loss: str ) -> Callable[["Vocab", Model], Model]: def create_vectors_objective(vocab: "Vocab", tok2vec: Model) -> Model: + if vocab.vectors.data.shape[1] == 0: + raise ValueError(Errors.E875) model = build_cloze_multi_task_model( vocab, tok2vec, hidden_size=hidden_size, maxout_pieces=maxout_pieces ) @@ -134,7 +136,7 @@ def build_cloze_characters_multi_task_model( ) -> Model: output_layer = chain( list2array(), - Maxout(hidden_size, nP=maxout_pieces), + Maxout(nO=hidden_size, nP=maxout_pieces), LayerNorm(nI=hidden_size), MultiSoftmax([256] * nr_char, nI=hidden_size), ) diff --git a/spacy/tests/serialize/test_serialize_config.py b/spacy/tests/serialize/test_serialize_config.py index 86f726c43..66b66b744 100644 --- a/spacy/tests/serialize/test_serialize_config.py +++ b/spacy/tests/serialize/test_serialize_config.py @@ -293,7 +293,7 @@ def test_serialize_parser(parser_config_string): def test_config_nlp_roundtrip(): - """Test that a config prduced by the nlp object passes training config + """Test that a config produced by the nlp object passes training config validation.""" nlp = English() nlp.add_pipe("entity_ruler") diff --git a/spacy/tests/test_cli.py b/spacy/tests/test_cli.py index a3834f31a..c36be9c57 100644 --- a/spacy/tests/test_cli.py +++ b/spacy/tests/test_cli.py @@ -4,7 +4,7 @@ from spacy.training import docs_to_json, offsets_to_biluo_tags from spacy.training.converters import iob_to_docs, conll_ner_to_docs, conllu_to_docs from spacy.schemas import ProjectConfigSchema, RecommendationSchema, validate from spacy.lang.nl import Dutch -from spacy.util import ENV_VARS +from spacy.util import ENV_VARS, 
load_model_from_config from spacy.cli import info from spacy.cli.init_config import init_config, RECOMMENDATIONS from spacy.cli._util import validate_project_commands, parse_config_overrides @@ -397,10 +397,14 @@ def test_parse_cli_overrides(): "pipeline", [["tagger", "parser", "ner"], [], ["ner", "textcat", "sentencizer"]] ) @pytest.mark.parametrize("optimize", ["efficiency", "accuracy"]) -def test_init_config(lang, pipeline, optimize): +@pytest.mark.parametrize("pretraining", [True, False]) +def test_init_config(lang, pipeline, optimize, pretraining): # TODO: add more tests and also check for GPU with transformers - config = init_config(lang=lang, pipeline=pipeline, optimize=optimize, gpu=False) + config = init_config(lang=lang, pipeline=pipeline, optimize=optimize, pretraining=pretraining, gpu=False) assert isinstance(config, Config) + if pretraining: + config["paths"]["raw_text"] = "my_data.jsonl" + nlp = load_model_from_config(config, auto_fill=True) def test_model_recommendations(): diff --git a/spacy/tests/training/test_pretraining.py b/spacy/tests/training/test_pretraining.py new file mode 100644 index 000000000..bd8810a5c --- /dev/null +++ b/spacy/tests/training/test_pretraining.py @@ -0,0 +1,345 @@ +from pathlib import Path +import numpy as np +import pytest +import srsly +from spacy.vocab import Vocab +from thinc.api import Config + +from ..util import make_tempdir +from ... import util +from ...lang.en import English +from ...training.initialize import init_nlp +from ...training.loop import train +from ...training.pretrain import pretrain +from ...tokens import Doc, DocBin +from ...language import DEFAULT_CONFIG_PRETRAIN_PATH, DEFAULT_CONFIG_PATH + +pretrain_string_listener = """ +[nlp] +lang = "en" +pipeline = ["tok2vec", "tagger"] + +[components] + +[components.tok2vec] +factory = "tok2vec" + +[components.tok2vec.model] +@architectures = "spacy.HashEmbedCNN.v1" +pretrained_vectors = null +width = 342 +depth = 4 +window_size = 1 +embed_size = 2000 +maxout_pieces = 3 +subword_features = true + +[components.tagger] +factory = "tagger" + +[components.tagger.model] +@architectures = "spacy.Tagger.v1" + +[components.tagger.model.tok2vec] +@architectures = "spacy.Tok2VecListener.v1" +width = ${components.tok2vec.model.width} + +[pretraining] +max_epochs = 5 + +[training] +max_epochs = 5 +""" + +pretrain_string_internal = """ +[nlp] +lang = "en" +pipeline = ["tagger"] + +[components] + +[components.tagger] +factory = "tagger" + +[components.tagger.model] +@architectures = "spacy.Tagger.v1" + +[components.tagger.model.tok2vec] +@architectures = "spacy.HashEmbedCNN.v1" +pretrained_vectors = null +width = 342 +depth = 4 +window_size = 1 +embed_size = 2000 +maxout_pieces = 3 +subword_features = true + +[pretraining] +max_epochs = 5 + +[training] +max_epochs = 5 +""" + + +pretrain_string_vectors = """ +[nlp] +lang = "en" +pipeline = ["tok2vec", "tagger"] + +[components] + +[components.tok2vec] +factory = "tok2vec" + +[components.tok2vec.model] +@architectures = "spacy.HashEmbedCNN.v1" +pretrained_vectors = null +width = 342 +depth = 4 +window_size = 1 +embed_size = 2000 +maxout_pieces = 3 +subword_features = true + +[components.tagger] +factory = "tagger" + +[components.tagger.model] +@architectures = "spacy.Tagger.v1" + +[components.tagger.model.tok2vec] +@architectures = "spacy.Tok2VecListener.v1" +width = ${components.tok2vec.model.width} + +[pretraining] +max_epochs = 5 + +[pretraining.objective] +@architectures = spacy.PretrainVectors.v1 +maxout_pieces = 3 +hidden_size = 300 +loss = 
cosine + +[training] +max_epochs = 5 +""" + +CHAR_OBJECTIVES = [ + {}, + {"@architectures": "spacy.PretrainCharacters.v1"}, + { + "@architectures": "spacy.PretrainCharacters.v1", + "maxout_pieces": 5, + "hidden_size": 42, + "n_characters": 2, + }, +] + +VECTOR_OBJECTIVES = [ + { + "@architectures": "spacy.PretrainVectors.v1", + "maxout_pieces": 3, + "hidden_size": 300, + "loss": "cosine", + }, + { + "@architectures": "spacy.PretrainVectors.v1", + "maxout_pieces": 2, + "hidden_size": 200, + "loss": "L2", + }, +] + + +def test_pretraining_default(): + """Test that pretraining defaults to a character objective""" + config = Config().from_str(pretrain_string_internal) + nlp = util.load_model_from_config(config, auto_fill=True, validate=False) + filled = nlp.config + pretrain_config = util.load_config(DEFAULT_CONFIG_PRETRAIN_PATH) + filled = pretrain_config.merge(filled) + assert "PretrainCharacters" in filled["pretraining"]["objective"]["@architectures"] + + +@pytest.mark.parametrize("objective", CHAR_OBJECTIVES) +def test_pretraining_tok2vec_characters(objective): + """Test that pretraining works with the character objective""" + config = Config().from_str(pretrain_string_listener) + config["pretraining"]["objective"] = objective + nlp = util.load_model_from_config(config, auto_fill=True, validate=False) + filled = nlp.config + pretrain_config = util.load_config(DEFAULT_CONFIG_PRETRAIN_PATH) + filled = pretrain_config.merge(filled) + with make_tempdir() as tmp_dir: + file_path = write_sample_jsonl(tmp_dir) + filled["paths"]["raw_text"] = file_path + filled = filled.interpolate() + assert filled["pretraining"]["component"] == "tok2vec" + pretrain(filled, tmp_dir) + assert Path(tmp_dir / "model0.bin").exists() + assert Path(tmp_dir / "model4.bin").exists() + assert not Path(tmp_dir / "model5.bin").exists() + + +@pytest.mark.parametrize("objective", VECTOR_OBJECTIVES) +def test_pretraining_tok2vec_vectors_fail(objective): + """Test that pretraining doesn't works with the vectors objective if there are no static vectors""" + config = Config().from_str(pretrain_string_listener) + config["pretraining"]["objective"] = objective + nlp = util.load_model_from_config(config, auto_fill=True, validate=False) + filled = nlp.config + pretrain_config = util.load_config(DEFAULT_CONFIG_PRETRAIN_PATH) + filled = pretrain_config.merge(filled) + with make_tempdir() as tmp_dir: + file_path = write_sample_jsonl(tmp_dir) + filled["paths"]["raw_text"] = file_path + filled = filled.interpolate() + assert filled["initialize"]["vectors"] is None + with pytest.raises(ValueError): + pretrain(filled, tmp_dir) + + +@pytest.mark.parametrize("objective", VECTOR_OBJECTIVES) +def test_pretraining_tok2vec_vectors(objective): + """Test that pretraining works with the vectors objective and static vectors defined""" + config = Config().from_str(pretrain_string_listener) + config["pretraining"]["objective"] = objective + nlp = util.load_model_from_config(config, auto_fill=True, validate=False) + filled = nlp.config + pretrain_config = util.load_config(DEFAULT_CONFIG_PRETRAIN_PATH) + filled = pretrain_config.merge(filled) + with make_tempdir() as tmp_dir: + file_path = write_sample_jsonl(tmp_dir) + filled["paths"]["raw_text"] = file_path + nlp_path = write_vectors_model(tmp_dir) + filled["initialize"]["vectors"] = nlp_path + filled = filled.interpolate() + pretrain(filled, tmp_dir) + + +@pytest.mark.parametrize("config", [pretrain_string_internal, pretrain_string_listener]) +def test_pretraining_tagger_tok2vec(config): + """Test 
pretraining of the tagger's tok2vec layer (via a listener)""" + config = Config().from_str(pretrain_string_listener) + nlp = util.load_model_from_config(config, auto_fill=True, validate=False) + filled = nlp.config + pretrain_config = util.load_config(DEFAULT_CONFIG_PRETRAIN_PATH) + filled = pretrain_config.merge(filled) + with make_tempdir() as tmp_dir: + file_path = write_sample_jsonl(tmp_dir) + filled["paths"]["raw_text"] = file_path + filled["pretraining"]["component"] = "tagger" + filled["pretraining"]["layer"] = "tok2vec" + filled = filled.interpolate() + pretrain(filled, tmp_dir) + assert Path(tmp_dir / "model0.bin").exists() + assert Path(tmp_dir / "model4.bin").exists() + assert not Path(tmp_dir / "model5.bin").exists() + + +def test_pretraining_tagger(): + """Test pretraining of the tagger itself will throw an error (not an appropriate tok2vec layer)""" + config = Config().from_str(pretrain_string_internal) + nlp = util.load_model_from_config(config, auto_fill=True, validate=False) + filled = nlp.config + pretrain_config = util.load_config(DEFAULT_CONFIG_PRETRAIN_PATH) + filled = pretrain_config.merge(filled) + with make_tempdir() as tmp_dir: + file_path = write_sample_jsonl(tmp_dir) + filled["paths"]["raw_text"] = file_path + filled["pretraining"]["component"] = "tagger" + filled = filled.interpolate() + with pytest.raises(ValueError): + pretrain(filled, tmp_dir) + + +def test_pretraining_training(): + """Test that training can use a pretrained Tok2Vec model""" + config = Config().from_str(pretrain_string_internal) + nlp = util.load_model_from_config(config, auto_fill=True, validate=False) + filled = nlp.config + pretrain_config = util.load_config(DEFAULT_CONFIG_PRETRAIN_PATH) + filled = pretrain_config.merge(filled) + train_config = util.load_config(DEFAULT_CONFIG_PATH) + filled = train_config.merge(filled) + with make_tempdir() as tmp_dir: + pretrain_dir = tmp_dir / "pretrain" + pretrain_dir.mkdir() + file_path = write_sample_jsonl(pretrain_dir) + filled["paths"]["raw_text"] = file_path + filled["pretraining"]["component"] = "tagger" + filled["pretraining"]["layer"] = "tok2vec" + train_dir = tmp_dir / "train" + train_dir.mkdir() + train_path, dev_path = write_sample_training(train_dir) + filled["paths"]["train"] = train_path + filled["paths"]["dev"] = dev_path + filled = filled.interpolate() + P = filled["pretraining"] + nlp_base = init_nlp(filled) + model_base = nlp_base.get_pipe(P["component"]).model.get_ref(P["layer"]).get_ref("embed") + embed_base = None + for node in model_base.walk(): + if node.name == "hashembed": + embed_base = node + pretrain(filled, pretrain_dir) + pretrained_model = Path(pretrain_dir / "model3.bin") + assert pretrained_model.exists() + filled["initialize"]["init_tok2vec"] = str(pretrained_model) + nlp = init_nlp(filled) + model = nlp.get_pipe(P["component"]).model.get_ref(P["layer"]).get_ref("embed") + embed = None + for node in model.walk(): + if node.name == "hashembed": + embed = node + # ensure that the tok2vec weights are actually changed by the pretraining + assert np.any(np.not_equal(embed.get_param("E"), embed_base.get_param("E"))) + train(nlp, train_dir) + + +def write_sample_jsonl(tmp_dir): + data = [ + { + "meta": {"id": "1"}, + "text": "This is the best TV you'll ever buy!", + "cats": {"pos": 1, "neg": 0}, + }, + { + "meta": {"id": "2"}, + "text": "I wouldn't buy this again.", + "cats": {"pos": 0, "neg": 1}, + }, + ] + file_path = f"{tmp_dir}/text.jsonl" + srsly.write_jsonl(file_path, data) + return file_path + + +def 
write_sample_training(tmp_dir): + words = ["The", "players", "start", "."] + tags = ["DT", "NN", "VBZ", "."] + doc = Doc(English().vocab, words=words, tags=tags) + doc_bin = DocBin() + doc_bin.add(doc) + train_path = f"{tmp_dir}/train.spacy" + dev_path = f"{tmp_dir}/dev.spacy" + doc_bin.to_disk(train_path) + doc_bin.to_disk(dev_path) + return train_path, dev_path + + +def write_vectors_model(tmp_dir): + import numpy + vocab = Vocab() + vector_data = { + "dog": numpy.random.uniform(-1, 1, (300,)), + "cat": numpy.random.uniform(-1, 1, (300,)), + "orange": numpy.random.uniform(-1, 1, (300,)) + } + for word, vector in vector_data.items(): + vocab.set_vector(word, vector) + nlp_path = tmp_dir / "vectors_model" + nlp = English(vocab) + nlp.to_disk(nlp_path) + return str(nlp_path) diff --git a/spacy/training/initialize.py b/spacy/training/initialize.py index 25bb73c78..f7f2f21a4 100644 --- a/spacy/training/initialize.py +++ b/spacy/training/initialize.py @@ -9,6 +9,7 @@ import gzip import zipfile import tqdm +from .pretrain import get_tok2vec_ref from ..lookups import Lookups from ..vectors import Vectors from ..errors import Errors, Warnings @@ -147,10 +148,6 @@ def init_tok2vec( weights_data = None init_tok2vec = ensure_path(I["init_tok2vec"]) if init_tok2vec is not None: - if P["objective"].get("type") == "vectors" and not I["vectors"]: - err = 'need initialize.vectors if pretraining.objective.type is "vectors"' - errors = [{"loc": ["initialize"], "msg": err}] - raise ConfigValidationError(config=nlp.config, errors=errors) if not init_tok2vec.exists(): err = f"can't find pretrained tok2vec: {init_tok2vec}" errors = [{"loc": ["initialize", "init_tok2vec"], "msg": err}] @@ -158,21 +155,9 @@ def init_tok2vec( with init_tok2vec.open("rb") as file_: weights_data = file_.read() if weights_data is not None: - tok2vec_component = P["component"] - if tok2vec_component is None: - desc = ( - f"To use pretrained tok2vec weights, [pretraining.component] " - f"needs to specify the component that should load them." - ) - err = "component can't be null" - errors = [{"loc": ["pretraining", "component"], "msg": err}] - raise ConfigValidationError( - config=nlp.config["pretraining"], errors=errors, desc=desc - ) - layer = nlp.get_pipe(tok2vec_component).model - if P["layer"]: - layer = layer.get_ref(P["layer"]) + layer = get_tok2vec_ref(nlp, P) layer.from_bytes(weights_data) + logger.info(f"Loaded pretrained weights from {init_tok2vec}") return True return False diff --git a/spacy/training/pretrain.py b/spacy/training/pretrain.py index 152d849e9..c791732db 100644 --- a/spacy/training/pretrain.py +++ b/spacy/training/pretrain.py @@ -6,9 +6,12 @@ from collections import Counter import srsly import time import re + +from thinc.config import ConfigValidationError from wasabi import Printer from .example import Example +from ..errors import Errors from ..tokens import Doc from ..schemas import ConfigSchemaPretrain from ..util import registry, load_model_from_config, dot_to_object @@ -133,12 +136,21 @@ def create_pretraining_model(nlp, pretrain_config): The actual tok2vec layer is stored as a reference, and only this bit will be serialized to file and read back in when calling the 'train' command. 
""" - nlp.initialize() - component = nlp.get_pipe(pretrain_config["component"]) - if pretrain_config.get("layer"): - tok2vec = component.model.get_ref(pretrain_config["layer"]) - else: - tok2vec = component.model + with nlp.select_pipes(enable=[]): + nlp.initialize() + tok2vec = get_tok2vec_ref(nlp, pretrain_config) + # If the config referred to a Tok2VecListener, grab the original model instead + if type(tok2vec).__name__ == "Tok2VecListener": + original_tok2vec = ( + tok2vec.upstream_name if tok2vec.upstream_name is not "*" else "tok2vec" + ) + tok2vec = nlp.get_pipe(original_tok2vec).model + try: + tok2vec.initialize(X=[nlp.make_doc("Give it a doc to infer shapes")]) + except ValueError: + component = pretrain_config["component"] + layer = pretrain_config["layer"] + raise ValueError(Errors.E874.format(component=component, layer=layer)) create_function = pretrain_config["objective"] model = create_function(nlp.vocab, tok2vec) @@ -147,6 +159,24 @@ def create_pretraining_model(nlp, pretrain_config): return model +def get_tok2vec_ref(nlp, pretrain_config): + tok2vec_component = pretrain_config["component"] + if tok2vec_component is None: + desc = ( + f"To use pretrained tok2vec weights, [pretraining.component] " + f"needs to specify the component that should load them." + ) + err = "component can't be null" + errors = [{"loc": ["pretraining", "component"], "msg": err}] + raise ConfigValidationError( + config=nlp.config["pretraining"], errors=errors, desc=desc + ) + layer = nlp.get_pipe(tok2vec_component).model + if pretrain_config["layer"]: + layer = layer.get_ref(pretrain_config["layer"]) + return layer + + class ProgressTracker: def __init__(self, frequency=1000000): self.loss = 0.0 diff --git a/website/docs/api/architectures.md b/website/docs/api/architectures.md index 1739836ed..793855d18 100644 --- a/website/docs/api/architectures.md +++ b/website/docs/api/architectures.md @@ -447,6 +447,9 @@ For more information, see the section on > ```ini > [pretraining] > component = "tok2vec" +> +> [initialize] +> vectors = "en_core_web_lg" > ... > > [pretraining.objective] @@ -457,7 +460,9 @@ For more information, see the section on > ``` Predict the word's vector from a static embeddings table as pretraining -objective for a Tok2Vec layer. +objective for a Tok2Vec layer. To use this objective, make sure that the +`initialize.vectors` section in the config refers to a model with static +vectors. | Name | Description | | --------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------- | From 3f3e8110dc6ec3c1449b120c29ffc7d5475ef622 Mon Sep 17 00:00:00 2001 From: Adriane Boyd Date: Tue, 9 Mar 2021 04:02:32 +0100 Subject: [PATCH 03/10] Fix lowercase augmentation (#7336) * Fix aborted/skipped augmentation for `spacy.orth_variants.v1` if lowercasing was enabled for an example * Simplify `spacy.orth_variants.v1` for `Example` vs. 
`GoldParse` * Preserve reference tokenization in `spacy.lower_case.v1` --- spacy/tests/training/test_augmenters.py | 61 ++++++++- spacy/training/augment.py | 158 +++++++++--------------- 2 files changed, 114 insertions(+), 105 deletions(-) diff --git a/spacy/tests/training/test_augmenters.py b/spacy/tests/training/test_augmenters.py index 0bd4d5ef2..43a78e4b0 100644 --- a/spacy/tests/training/test_augmenters.py +++ b/spacy/tests/training/test_augmenters.py @@ -38,19 +38,59 @@ def doc(nlp): @pytest.mark.filterwarnings("ignore::UserWarning") -def test_make_orth_variants(nlp, doc): +def test_make_orth_variants(nlp): single = [ {"tags": ["NFP"], "variants": ["…", "..."]}, {"tags": [":"], "variants": ["-", "—", "–", "--", "---", "——"]}, ] + # fmt: off + words = ["\n\n", "A", "\t", "B", "a", "b", "…", "...", "-", "—", "–", "--", "---", "——"] + tags = ["_SP", "NN", "\t", "NN", "NN", "NN", "NFP", "NFP", ":", ":", ":", ":", ":", ":"] + # fmt: on + spaces = [True] * len(words) + spaces[0] = False + spaces[2] = False + doc = Doc(nlp.vocab, words=words, spaces=spaces, tags=tags) augmenter = create_orth_variants_augmenter( level=0.2, lower=0.5, orth_variants={"single": single} ) - with make_docbin([doc]) as output_file: + with make_docbin([doc] * 10) as output_file: reader = Corpus(output_file, augmenter=augmenter) - # Due to randomness, only test that it works without errors for now + # Due to randomness, only test that it works without errors list(reader(nlp)) + # check that the following settings lowercase everything + augmenter = create_orth_variants_augmenter( + level=1.0, lower=1.0, orth_variants={"single": single} + ) + with make_docbin([doc] * 10) as output_file: + reader = Corpus(output_file, augmenter=augmenter) + for example in reader(nlp): + for token in example.reference: + assert token.text == token.text.lower() + + # check that lowercasing is applied without tags + doc = Doc(nlp.vocab, words=words, spaces=[True] * len(words)) + augmenter = create_orth_variants_augmenter( + level=1.0, lower=1.0, orth_variants={"single": single} + ) + with make_docbin([doc] * 10) as output_file: + reader = Corpus(output_file, augmenter=augmenter) + for example in reader(nlp): + for ex_token, doc_token in zip(example.reference, doc): + assert ex_token.text == doc_token.text.lower() + + # check that no lowercasing is applied with lower=0.0 + doc = Doc(nlp.vocab, words=words, spaces=[True] * len(words)) + augmenter = create_orth_variants_augmenter( + level=1.0, lower=0.0, orth_variants={"single": single} + ) + with make_docbin([doc] * 10) as output_file: + reader = Corpus(output_file, augmenter=augmenter) + for example in reader(nlp): + for ex_token, doc_token in zip(example.reference, doc): + assert ex_token.text == doc_token.text + def test_lowercase_augmenter(nlp, doc): augmenter = create_lower_casing_augmenter(level=1.0) @@ -66,6 +106,21 @@ def test_lowercase_augmenter(nlp, doc): assert ref_ent.text == orig_ent.text.lower() assert [t.pos_ for t in eg.reference] == [t.pos_ for t in doc] + # check that augmentation works when lowercasing leads to different + # predicted tokenization + words = ["A", "B", "CCC."] + doc = Doc(nlp.vocab, words=words) + with make_docbin([doc]) as output_file: + reader = Corpus(output_file, augmenter=augmenter) + corpus = list(reader(nlp)) + eg = corpus[0] + assert eg.reference.text == doc.text.lower() + assert eg.predicted.text == doc.text.lower() + assert [t.text for t in eg.reference] == [t.lower() for t in words] + assert [t.text for t in eg.predicted] == [ + t.text for t in 
nlp.make_doc(doc.text.lower()) + ] + @pytest.mark.filterwarnings("ignore::UserWarning") def test_custom_data_augmentation(nlp, doc): diff --git a/spacy/training/augment.py b/spacy/training/augment.py index 13ae45bd2..0dae92143 100644 --- a/spacy/training/augment.py +++ b/spacy/training/augment.py @@ -1,12 +1,10 @@ from typing import Callable, Iterator, Dict, List, Tuple, TYPE_CHECKING import random import itertools -import copy from functools import partial from pydantic import BaseModel, StrictStr from ..util import registry -from ..tokens import Doc from .example import Example if TYPE_CHECKING: @@ -71,7 +69,7 @@ def lower_casing_augmenter( else: example_dict = example.to_dict() doc = nlp.make_doc(example.text.lower()) - example_dict["token_annotation"]["ORTH"] = [t.lower_ for t in doc] + example_dict["token_annotation"]["ORTH"] = [t.lower_ for t in example.reference] yield example.from_dict(doc, example_dict) @@ -88,24 +86,15 @@ def orth_variants_augmenter( else: raw_text = example.text orig_dict = example.to_dict() - if not orig_dict["token_annotation"]: - yield example - else: - variant_text, variant_token_annot = make_orth_variants( - nlp, - raw_text, - orig_dict["token_annotation"], - orth_variants, - lower=raw_text is not None and random.random() < lower, - ) - if variant_text: - doc = nlp.make_doc(variant_text) - else: - doc = Doc(nlp.vocab, words=variant_token_annot["ORTH"]) - variant_token_annot["ORTH"] = [w.text for w in doc] - variant_token_annot["SPACY"] = [w.whitespace_ for w in doc] - orig_dict["token_annotation"] = variant_token_annot - yield example.from_dict(doc, orig_dict) + variant_text, variant_token_annot = make_orth_variants( + nlp, + raw_text, + orig_dict["token_annotation"], + orth_variants, + lower=raw_text is not None and random.random() < lower, + ) + orig_dict["token_annotation"] = variant_token_annot + yield example.from_dict(nlp.make_doc(variant_text), orig_dict) def make_orth_variants( @@ -116,88 +105,53 @@ def make_orth_variants( *, lower: bool = False, ) -> Tuple[str, Dict[str, List[str]]]: - orig_token_dict = copy.deepcopy(token_dict) - ndsv = orth_variants.get("single", []) - ndpv = orth_variants.get("paired", []) words = token_dict.get("ORTH", []) tags = token_dict.get("TAG", []) - # keep unmodified if words or tags are not defined - if words and tags: - if lower: - words = [w.lower() for w in words] - # single variants - punct_choices = [random.choice(x["variants"]) for x in ndsv] - for word_idx in range(len(words)): - for punct_idx in range(len(ndsv)): - if ( - tags[word_idx] in ndsv[punct_idx]["tags"] - and words[word_idx] in ndsv[punct_idx]["variants"] - ): - words[word_idx] = punct_choices[punct_idx] - # paired variants - punct_choices = [random.choice(x["variants"]) for x in ndpv] - for word_idx in range(len(words)): - for punct_idx in range(len(ndpv)): - if tags[word_idx] in ndpv[punct_idx]["tags"] and words[ - word_idx - ] in itertools.chain.from_iterable(ndpv[punct_idx]["variants"]): - # backup option: random left vs. 
right from pair - pair_idx = random.choice([0, 1]) - # best option: rely on paired POS tags like `` / '' - if len(ndpv[punct_idx]["tags"]) == 2: - pair_idx = ndpv[punct_idx]["tags"].index(tags[word_idx]) - # next best option: rely on position in variants - # (may not be unambiguous, so order of variants matters) - else: - for pair in ndpv[punct_idx]["variants"]: - if words[word_idx] in pair: - pair_idx = pair.index(words[word_idx]) - words[word_idx] = punct_choices[punct_idx][pair_idx] + # keep unmodified if words are not defined + if not words: + return raw, token_dict + if lower: + words = [w.lower() for w in words] + raw = raw.lower() + # if no tags, only lowercase + if not tags: token_dict["ORTH"] = words - token_dict["TAG"] = tags - # modify raw - if raw is not None: - variants = [] - for single_variants in ndsv: - variants.extend(single_variants["variants"]) - for paired_variants in ndpv: - variants.extend( - list(itertools.chain.from_iterable(paired_variants["variants"])) - ) - # store variants in reverse length order to be able to prioritize - # longer matches (e.g., "---" before "--") - variants = sorted(variants, key=lambda x: len(x)) - variants.reverse() - variant_raw = "" - raw_idx = 0 - # add initial whitespace - while raw_idx < len(raw) and raw[raw_idx].isspace(): - variant_raw += raw[raw_idx] - raw_idx += 1 - for word in words: - match_found = False - # skip whitespace words - if word.isspace(): - match_found = True - # add identical word - elif word not in variants and raw[raw_idx:].startswith(word): - variant_raw += word - raw_idx += len(word) - match_found = True - # add variant word - else: - for variant in variants: - if not match_found and raw[raw_idx:].startswith(variant): - raw_idx += len(variant) - variant_raw += word - match_found = True - # something went wrong, abort - # (add a warning message?) - if not match_found: - return raw, orig_token_dict - # add following whitespace - while raw_idx < len(raw) and raw[raw_idx].isspace(): - variant_raw += raw[raw_idx] - raw_idx += 1 - raw = variant_raw + return raw, token_dict + # single variants + ndsv = orth_variants.get("single", []) + punct_choices = [random.choice(x["variants"]) for x in ndsv] + for word_idx in range(len(words)): + for punct_idx in range(len(ndsv)): + if ( + tags[word_idx] in ndsv[punct_idx]["tags"] + and words[word_idx] in ndsv[punct_idx]["variants"] + ): + words[word_idx] = punct_choices[punct_idx] + # paired variants + ndpv = orth_variants.get("paired", []) + punct_choices = [random.choice(x["variants"]) for x in ndpv] + for word_idx in range(len(words)): + for punct_idx in range(len(ndpv)): + if tags[word_idx] in ndpv[punct_idx]["tags"] and words[ + word_idx + ] in itertools.chain.from_iterable(ndpv[punct_idx]["variants"]): + # backup option: random left vs. 
right from pair + pair_idx = random.choice([0, 1]) + # best option: rely on paired POS tags like `` / '' + if len(ndpv[punct_idx]["tags"]) == 2: + pair_idx = ndpv[punct_idx]["tags"].index(tags[word_idx]) + # next best option: rely on position in variants + # (may not be unambiguous, so order of variants matters) + else: + for pair in ndpv[punct_idx]["variants"]: + if words[word_idx] in pair: + pair_idx = pair.index(words[word_idx]) + words[word_idx] = punct_choices[punct_idx][pair_idx] + token_dict["ORTH"] = words + # construct modified raw text from words and spaces + raw = "" + for orth, spacy in zip(token_dict["ORTH"], token_dict["SPACY"]): + raw += orth + if spacy: + raw += " " return raw, token_dict From f26b61e0015a144c5b97d5087c9cb7486f55dbfc Mon Sep 17 00:00:00 2001 From: Jan Krepl Date: Tue, 9 Mar 2021 10:49:53 +0100 Subject: [PATCH 04/10] Make sure sorted --- spacy/pipeline/entityruler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spacy/pipeline/entityruler.py b/spacy/pipeline/entityruler.py index 25bc3abee..4e61dbca7 100644 --- a/spacy/pipeline/entityruler.py +++ b/spacy/pipeline/entityruler.py @@ -195,7 +195,7 @@ class EntityRuler(Pipe): all_labels.add(label) else: all_labels.add(l) - return tuple(all_labels) + return tuple(sorted(all_labels)) def initialize( self, From 0e1d579f0c65dbcbde05e3430e64c299d58205df Mon Sep 17 00:00:00 2001 From: Jan Krepl Date: Tue, 9 Mar 2021 10:57:32 +0100 Subject: [PATCH 05/10] Add agreement --- .github/contributors/jankrepl.md | 106 +++++++++++++++++++++++++++++++ 1 file changed, 106 insertions(+) create mode 100644 .github/contributors/jankrepl.md diff --git a/.github/contributors/jankrepl.md b/.github/contributors/jankrepl.md new file mode 100644 index 000000000..eda5a29b8 --- /dev/null +++ b/.github/contributors/jankrepl.md @@ -0,0 +1,106 @@ +# spaCy contributor agreement + +This spaCy Contributor Agreement (**"SCA"**) is based on the +[Oracle Contributor Agreement](http://www.oracle.com/technetwork/oca-405177.pdf). +The SCA applies to any contribution that you make to any product or project +managed by us (the **"project"**), and sets out the intellectual property rights +you grant to us in the contributed materials. The term **"us"** shall mean +[ExplosionAI GmbH](https://explosion.ai/legal). The term +**"you"** shall mean the person or entity identified below. + +If you agree to be bound by these terms, fill in the information requested +below and include the filled-in version with your first pull request, under the +folder [`.github/contributors/`](/.github/contributors/). The name of the file +should be your GitHub username, with the extension `.md`. For example, the user +example_user would create the file `.github/contributors/example_user.md`. + +Read this agreement carefully before signing. These terms and conditions +constitute a binding legal agreement. + +## Contributor Agreement + +1. The term "contribution" or "contributed materials" means any source code, +object code, patch, tool, sample, graphic, specification, manual, +documentation, or any other material posted or submitted by you to the project. + +2. 
With respect to any worldwide copyrights, or copyright applications and +registrations, in your contribution: + + * you hereby assign to us joint ownership, and to the extent that such + assignment is or becomes invalid, ineffective or unenforceable, you hereby + grant to us a perpetual, irrevocable, non-exclusive, worldwide, no-charge, + royalty-free, unrestricted license to exercise all rights under those + copyrights. This includes, at our option, the right to sublicense these same + rights to third parties through multiple levels of sublicensees or other + licensing arrangements; + + * you agree that each of us can do all things in relation to your + contribution as if each of us were the sole owners, and if one of us makes + a derivative work of your contribution, the one who makes the derivative + work (or has it made will be the sole owner of that derivative work; + + * you agree that you will not assert any moral rights in your contribution + against us, our licensees or transferees; + + * you agree that we may register a copyright in your contribution and + exercise all ownership rights associated with it; and + + * you agree that neither of us has any duty to consult with, obtain the + consent of, pay or render an accounting to the other for any use or + distribution of your contribution. + +3. With respect to any patents you own, or that you can license without payment +to any third party, you hereby grant to us a perpetual, irrevocable, +non-exclusive, worldwide, no-charge, royalty-free license to: + + * make, have made, use, sell, offer to sell, import, and otherwise transfer + your contribution in whole or in part, alone or in combination with or + included in any product, work or materials arising out of the project to + which your contribution was submitted, and + + * at our option, to sublicense these same rights to third parties through + multiple levels of sublicensees or other licensing arrangements. + +4. Except as set out above, you keep all right, title, and interest in your +contribution. The rights that you grant to us under these terms are effective +on the date you first submitted a contribution to us, even if your submission +took place before the date you sign these terms. + +5. You covenant, represent, warrant and agree that: + + * Each contribution that you submit is and shall be an original work of + authorship and you can legally grant the rights set out in this SCA; + + * to the best of your knowledge, each contribution will not violate any + third party's copyrights, trademarks, patents, or other intellectual + property rights; and + + * each contribution shall be in compliance with U.S. export control laws and + other applicable export and import laws. You agree to notify us if you + become aware of any circumstance which would make any of the foregoing + representations inaccurate in any respect. We may publicly disclose your + participation in the project, including the fact that you have signed the SCA. + +6. This SCA is governed by the laws of the State of California and applicable +U.S. Federal law. Any choice of law rules will not apply. + +7. Please place an “x” on one of the applicable statement below. Please do NOT +mark both statements: + + * [x] I am signing on behalf of myself as an individual and no other person + or entity, including my employer, has or will have rights with respect to my + contributions. + + * [ ] I am signing on behalf of my employer or a legal entity and I have the + actual authority to contractually bind that entity. 
+ +## Contributor Details + +| Field | Entry | +|------------------------------- | -------------------- | +| Name | Jan Krepl | +| Company name (if applicable) | | +| Title or role (if applicable) | | +| Date | 2021-03-09 | +| GitHub username | jankrepl | +| Website (optional) | | From 39de3602e0321eb2dbcfce032a8d4734162ee69d Mon Sep 17 00:00:00 2001 From: Sofie Van Landeghem Date: Tue, 9 Mar 2021 13:01:31 +0100 Subject: [PATCH 06/10] return custom error in nlp.initialize (#7104) * return custom error in nlp.initialize * Rename error Co-authored-by: Ines Montani --- spacy/errors.py | 4 ++++ spacy/language.py | 9 ++++++--- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/spacy/errors.py b/spacy/errors.py index 4f61cf098..e50a658d8 100644 --- a/spacy/errors.py +++ b/spacy/errors.py @@ -497,6 +497,10 @@ class Errors: E880 = ("The 'wandb' library could not be found - did you install it? " "Alternatively, specify the 'ConsoleLogger' in the 'training.logger' " "config section, instead of the 'WandbLogger'.") + E884 = ("The pipeline could not be initialized because the vectors " + "could not be found at '{vectors}'. If your pipeline was already " + "initialized/trained before, call 'resume_training' instead of 'initialize', " + "or initialize only the components that are new.") E885 = ("entity_linker.set_kb received an invalid 'kb_loader' argument: expected " "a callable function, but got: {arg_type}") E886 = ("Can't replace {name} -> {tok2vec} listeners: path '{path}' not " diff --git a/spacy/language.py b/spacy/language.py index 80de94278..5741ef97c 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -1219,9 +1219,12 @@ class Language: before_init = I["before_init"] if before_init is not None: before_init(self) - init_vocab( - self, data=I["vocab_data"], lookups=I["lookups"], vectors=I["vectors"] - ) + try: + init_vocab( + self, data=I["vocab_data"], lookups=I["lookups"], vectors=I["vectors"] + ) + except IOError: + raise IOError(Errors.E884.format(vectors=I["vectors"])) if self.vocab.vectors.data.shape[1] >= 1: ops = get_current_ops() self.vocab.vectors.data = ops.asarray(self.vocab.vectors.data) From 932887b950751020d3fb4b7f83a5a27b5512faf1 Mon Sep 17 00:00:00 2001 From: Sofie Van Landeghem Date: Tue, 9 Mar 2021 13:04:22 +0100 Subject: [PATCH 07/10] textcat scoring fix and multi_label docs (#6974) * add multi-label textcat to menu * add infobox on textcat API * add info to v3 migration guide * small edits * further fixes in doc strings * add infobox to textcat architectures * add textcat_multilabel to overview of built-in components * spelling * fix unrelated warn msg * Add textcat_multilabel to quickstart [ci skip] * remove separate documentation page for multilabel_textcategorizer * small edits * positive label clarification * avoid duplicating information in self.cfg and fix textcat.score * fix multilabel textcat too * revert threshold to storage in cfg * revert threshold stuff for multi-textcat Co-authored-by: Ines Montani --- spacy/cli/download.py | 2 +- spacy/pipeline/textcat.py | 16 +- spacy/pipeline/textcat_multilabel.py | 19 +- spacy/tests/pipeline/test_textcat.py | 48 ++ website/docs/api/architectures.md | 11 + .../docs/api/multilabel_textcategorizer.md | 453 ------------------ website/docs/api/textcategorizer.md | 66 ++- website/docs/usage/processing-pipelines.md | 37 +- website/docs/usage/v3.md | 21 +- website/src/widgets/quickstart-training.js | 38 +- 10 files changed, 191 insertions(+), 520 deletions(-) delete mode 100644 
website/docs/api/multilabel_textcategorizer.md diff --git a/spacy/cli/download.py b/spacy/cli/download.py index dbda8578a..d09d5147a 100644 --- a/spacy/cli/download.py +++ b/spacy/cli/download.py @@ -60,7 +60,7 @@ def download(model: str, direct: bool = False, sdist: bool = False, *pip_args) - model_name = model if model in OLD_MODEL_SHORTCUTS: msg.warn( - f"As of spaCy v3.0, shortcuts like '{model}' are deprecated. Please" + f"As of spaCy v3.0, shortcuts like '{model}' are deprecated. Please " f"use the full pipeline package name '{OLD_MODEL_SHORTCUTS[model]}' instead." ) model_name = OLD_MODEL_SHORTCUTS[model] diff --git a/spacy/pipeline/textcat.py b/spacy/pipeline/textcat.py index f94bde84f..174ffd273 100644 --- a/spacy/pipeline/textcat.py +++ b/spacy/pipeline/textcat.py @@ -88,11 +88,9 @@ subword_features = true def make_textcat( nlp: Language, name: str, model: Model[List[Doc], List[Floats2d]], threshold: float ) -> "TextCategorizer": - """Create a TextCategorizer compoment. The text categorizer predicts categories - over a whole document. It can learn one or more labels, and the labels can - be mutually exclusive (i.e. one true label per doc) or non-mutually exclusive - (i.e. zero or more labels may be true per doc). The multi-label setting is - controlled by the model instance that's provided. + """Create a TextCategorizer component. The text categorizer predicts categories + over a whole document. It can learn one or more labels, and the labels are considered + to be mutually exclusive (i.e. one true label per doc). model (Model[List[Doc], List[Floats2d]]): A model instance that predicts scores for each category. @@ -317,9 +315,11 @@ class TextCategorizer(TrainablePipe): get_examples (Callable[[], Iterable[Example]]): Function that returns a representative sample of gold-standard Example objects. nlp (Language): The current nlp object the component is part of. - labels: The labels to add to the component, typically generated by the + labels (Optional[Iterable[str]]): The labels to add to the component, typically generated by the `init labels` command. If no labels are provided, the get_examples callback is used to extract the labels from the data. + positive_label (Optional[str]): The positive label for a binary task with exclusive classes, + `None` otherwise and by default. DOCS: https://spacy.io/api/textcategorizer#initialize """ @@ -358,13 +358,13 @@ class TextCategorizer(TrainablePipe): """ validate_examples(examples, "TextCategorizer.score") self._validate_categories(examples) + kwargs.setdefault("threshold", self.cfg["threshold"]) + kwargs.setdefault("positive_label", self.cfg["positive_label"]) return Scorer.score_cats( examples, "cats", labels=self.labels, multi_label=False, - positive_label=self.cfg["positive_label"], - threshold=self.cfg["threshold"], **kwargs, ) diff --git a/spacy/pipeline/textcat_multilabel.py b/spacy/pipeline/textcat_multilabel.py index dc4b17940..036bc8dc5 100644 --- a/spacy/pipeline/textcat_multilabel.py +++ b/spacy/pipeline/textcat_multilabel.py @@ -88,11 +88,10 @@ subword_features = true def make_multilabel_textcat( nlp: Language, name: str, model: Model[List[Doc], List[Floats2d]], threshold: float ) -> "TextCategorizer": - """Create a TextCategorizer compoment. The text categorizer predicts categories - over a whole document. It can learn one or more labels, and the labels can - be mutually exclusive (i.e. one true label per doc) or non-mutually exclusive - (i.e. zero or more labels may be true per doc). 
The multi-label setting is - controlled by the model instance that's provided. + """Create a TextCategorizer component. The text categorizer predicts categories + over a whole document. It can learn one or more labels, and the labels are considered + to be non-mutually exclusive, which means that there can be zero or more labels + per doc). model (Model[List[Doc], List[Floats2d]]): A model instance that predicts scores for each category. @@ -104,7 +103,7 @@ def make_multilabel_textcat( class MultiLabel_TextCategorizer(TextCategorizer): """Pipeline component for multi-label text classification. - DOCS: https://spacy.io/api/multilabel_textcategorizer + DOCS: https://spacy.io/api/textcategorizer """ def __init__( @@ -123,7 +122,7 @@ class MultiLabel_TextCategorizer(TextCategorizer): losses during training. threshold (float): Cutoff to consider a prediction "positive". - DOCS: https://spacy.io/api/multilabel_textcategorizer#init + DOCS: https://spacy.io/api/textcategorizer#init """ self.vocab = vocab self.model = model @@ -149,7 +148,7 @@ class MultiLabel_TextCategorizer(TextCategorizer): `init labels` command. If no labels are provided, the get_examples callback is used to extract the labels from the data. - DOCS: https://spacy.io/api/multilabel_textcategorizer#initialize + DOCS: https://spacy.io/api/textcategorizer#initialize """ validate_get_examples(get_examples, "MultiLabel_TextCategorizer.initialize") if labels is None: @@ -173,15 +172,15 @@ class MultiLabel_TextCategorizer(TextCategorizer): examples (Iterable[Example]): The examples to score. RETURNS (Dict[str, Any]): The scores, produced by Scorer.score_cats. - DOCS: https://spacy.io/api/multilabel_textcategorizer#score + DOCS: https://spacy.io/api/textcategorizer#score """ validate_examples(examples, "MultiLabel_TextCategorizer.score") + kwargs.setdefault("threshold", self.cfg["threshold"]) return Scorer.score_cats( examples, "cats", labels=self.labels, multi_label=True, - threshold=self.cfg["threshold"], **kwargs, ) diff --git a/spacy/tests/pipeline/test_textcat.py b/spacy/tests/pipeline/test_textcat.py index 2b01a9cc8..61af16eb5 100644 --- a/spacy/tests/pipeline/test_textcat.py +++ b/spacy/tests/pipeline/test_textcat.py @@ -370,3 +370,51 @@ def test_textcat_evaluation(): assert scores["cats_micro_p"] == 4 / 5 assert scores["cats_micro_r"] == 4 / 6 + + +def test_textcat_threshold(): + # Ensure the scorer can be called with a different threshold + nlp = English() + nlp.add_pipe("textcat") + + train_examples = [] + for text, annotations in TRAIN_DATA_SINGLE_LABEL: + train_examples.append(Example.from_dict(nlp.make_doc(text), annotations)) + nlp.initialize(get_examples=lambda: train_examples) + + # score the model (it's not actually trained but that doesn't matter) + scores = nlp.evaluate(train_examples) + assert 0 <= scores["cats_score"] <= 1 + + scores = nlp.evaluate(train_examples, scorer_cfg={"threshold": 1.0}) + assert scores["cats_f_per_type"]["POSITIVE"]["r"] == 0 + + scores = nlp.evaluate(train_examples, scorer_cfg={"threshold": 0}) + macro_f = scores["cats_score"] + assert scores["cats_f_per_type"]["POSITIVE"]["r"] == 1.0 + + scores = nlp.evaluate(train_examples, scorer_cfg={"threshold": 0, "positive_label": "POSITIVE"}) + pos_f = scores["cats_score"] + assert scores["cats_f_per_type"]["POSITIVE"]["r"] == 1.0 + assert pos_f > macro_f + + +def test_textcat_multi_threshold(): + # Ensure the scorer can be called with a different threshold + nlp = English() + nlp.add_pipe("textcat_multilabel") + + train_examples = [] + for text, 
annotations in TRAIN_DATA_SINGLE_LABEL: + train_examples.append(Example.from_dict(nlp.make_doc(text), annotations)) + nlp.initialize(get_examples=lambda: train_examples) + + # score the model (it's not actually trained but that doesn't matter) + scores = nlp.evaluate(train_examples) + assert 0 <= scores["cats_score"] <= 1 + + scores = nlp.evaluate(train_examples, scorer_cfg={"threshold": 1.0}) + assert scores["cats_f_per_type"]["POSITIVE"]["r"] == 0 + + scores = nlp.evaluate(train_examples, scorer_cfg={"threshold": 0}) + assert scores["cats_f_per_type"]["POSITIVE"]["r"] == 1.0 diff --git a/website/docs/api/architectures.md b/website/docs/api/architectures.md index 793855d18..9b099d8e2 100644 --- a/website/docs/api/architectures.md +++ b/website/docs/api/architectures.md @@ -589,6 +589,17 @@ several different built-in architectures. It is recommended to experiment with different architectures and settings to determine what works best on your specific data and challenge. + + +When the architecture for a text classification challenge contains a setting for +`exclusive_classes`, it is important to use the correct value for the correct +pipeline component. The `textcat` component should always be used for +single-label use-cases where `exclusive_classes = true`, while the +`textcat_multilabel` should be used for multi-label settings with +`exclusive_classes = false`. + + + ### spacy.TextCatEnsemble.v2 {#TextCatEnsemble} > #### Example Config diff --git a/website/docs/api/multilabel_textcategorizer.md b/website/docs/api/multilabel_textcategorizer.md deleted file mode 100644 index 6e1a627c6..000000000 --- a/website/docs/api/multilabel_textcategorizer.md +++ /dev/null @@ -1,453 +0,0 @@ ---- -title: Multi-label TextCategorizer -tag: class -source: spacy/pipeline/textcat_multilabel.py -new: 3 -teaser: 'Pipeline component for multi-label text classification' -api_base_class: /api/pipe -api_string_name: textcat_multilabel -api_trainable: true ---- - -The text categorizer predicts **categories over a whole document**. It -learns non-mutually exclusive labels, which means that zero or more labels -may be true per document. - -## Config and implementation {#config} - -The default config is defined by the pipeline component factory and describes -how the component should be configured. You can override its settings via the -`config` argument on [`nlp.add_pipe`](/api/language#add_pipe) or in your -[`config.cfg` for training](/usage/training#config). See the -[model architectures](/api/architectures) documentation for details on the -architectures and their arguments and hyperparameters. - -> #### Example -> -> ```python -> from spacy.pipeline.textcat_multilabel import DEFAULT_MULTI_TEXTCAT_MODEL -> config = { -> "threshold": 0.5, -> "model": DEFAULT_MULTI_TEXTCAT_MODEL, -> } -> nlp.add_pipe("textcat_multilabel", config=config) -> ``` - -| Setting | Description | -| ----------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `threshold` | Cutoff to consider a prediction "positive", relevant when printing accuracy results. ~~float~~ | -| `model` | A model instance that predicts scores for each category. Defaults to [TextCatEnsemble](/api/architectures#TextCatEnsemble). 
~~Model[List[Doc], List[Floats2d]]~~ | - -```python -%%GITHUB_SPACY/spacy/pipeline/textcat_multilabel.py -``` - -## MultiLabel_TextCategorizer.\_\_init\_\_ {#init tag="method"} - -> #### Example -> -> ```python -> # Construction via add_pipe with default model -> textcat = nlp.add_pipe("textcat_multilabel") -> -> # Construction via add_pipe with custom model -> config = {"model": {"@architectures": "my_textcat"}} -> parser = nlp.add_pipe("textcat_multilabel", config=config) -> -> # Construction from class -> from spacy.pipeline import MultiLabel_TextCategorizer -> textcat = MultiLabel_TextCategorizer(nlp.vocab, model, threshold=0.5) -> ``` - -Create a new pipeline instance. In your application, you would normally use a -shortcut for this and instantiate the component using its string name and -[`nlp.add_pipe`](/api/language#create_pipe). - -| Name | Description | -| -------------- | -------------------------------------------------------------------------------------------------------------------------- | -| `vocab` | The shared vocabulary. ~~Vocab~~ | -| `model` | The Thinc [`Model`](https://thinc.ai/docs/api-model) powering the pipeline component. ~~Model[List[Doc], List[Floats2d]]~~ | -| `name` | String name of the component instance. Used to add entries to the `losses` during training. ~~str~~ | -| _keyword-only_ | | -| `threshold` | Cutoff to consider a prediction "positive", relevant when printing accuracy results. ~~float~~ | - -## MultiLabel_TextCategorizer.\_\_call\_\_ {#call tag="method"} - -Apply the pipe to one document. The document is modified in place, and returned. -This usually happens under the hood when the `nlp` object is called on a text -and all pipeline components are applied to the `Doc` in order. Both -[`__call__`](/api/multilabel_textcategorizer#call) and [`pipe`](/api/multilabel_textcategorizer#pipe) -delegate to the [`predict`](/api/multilabel_textcategorizer#predict) and -[`set_annotations`](/api/multilabel_textcategorizer#set_annotations) methods. - -> #### Example -> -> ```python -> doc = nlp("This is a sentence.") -> textcat = nlp.add_pipe("textcat_multilabel") -> # This usually happens under the hood -> processed = textcat(doc) -> ``` - -| Name | Description | -| ----------- | -------------------------------- | -| `doc` | The document to process. ~~Doc~~ | -| **RETURNS** | The processed document. ~~Doc~~ | - -## MultiLabel_TextCategorizer.pipe {#pipe tag="method"} - -Apply the pipe to a stream of documents. This usually happens under the hood -when the `nlp` object is called on a text and all pipeline components are -applied to the `Doc` in order. Both [`__call__`](/api/multilabel_textcategorizer#call) and -[`pipe`](/api/multilabel_textcategorizer#pipe) delegate to the -[`predict`](/api/multilabel_textcategorizer#predict) and -[`set_annotations`](/api/multilabel_textcategorizer#set_annotations) methods. - -> #### Example -> -> ```python -> textcat = nlp.add_pipe("textcat_multilabel") -> for doc in textcat.pipe(docs, batch_size=50): -> pass -> ``` - -| Name | Description | -| -------------- | ------------------------------------------------------------- | -| `stream` | A stream of documents. ~~Iterable[Doc]~~ | -| _keyword-only_ | | -| `batch_size` | The number of documents to buffer. Defaults to `128`. ~~int~~ | -| **YIELDS** | The processed documents in order. ~~Doc~~ | - -## MultiLabel_TextCategorizer.initialize {#initialize tag="method" new="3"} - -Initialize the component for training. 
`get_examples` should be a function that -returns an iterable of [`Example`](/api/example) objects. The data examples are -used to **initialize the model** of the component and can either be the full -training data or a representative sample. Initialization includes validating the -network, -[inferring missing shapes](https://thinc.ai/docs/usage-models#validation) and -setting up the label scheme based on the data. This method is typically called -by [`Language.initialize`](/api/language#initialize) and lets you customize -arguments it receives via the -[`[initialize.components]`](/api/data-formats#config-initialize) block in the -config. - - - -This method was previously called `begin_training`. - - - -> #### Example -> -> ```python -> textcat = nlp.add_pipe("textcat_multilabel") -> textcat.initialize(lambda: [], nlp=nlp) -> ``` -> -> ```ini -> ### config.cfg -> [initialize.components.textcat_multilabel] -> -> [initialize.components.textcat_multilabel.labels] -> @readers = "spacy.read_labels.v1" -> path = "corpus/labels/textcat.json -> ``` - -| Name | Description | -| ---------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `get_examples` | Function that returns gold-standard annotations in the form of [`Example`](/api/example) objects. ~~Callable[[], Iterable[Example]]~~ | -| _keyword-only_ | | -| `nlp` | The current `nlp` object. Defaults to `None`. ~~Optional[Language]~~ | -| `labels` | The label information to add to the component, as provided by the [`label_data`](#label_data) property after initialization. To generate a reusable JSON file from your data, you should run the [`init labels`](/api/cli#init-labels) command. If no labels are provided, the `get_examples` callback is used to extract the labels from the data, which may be a lot slower. ~~Optional[Iterable[str]]~~ | - -## MultiLabel_TextCategorizer.predict {#predict tag="method"} - -Apply the component's model to a batch of [`Doc`](/api/doc) objects without -modifying them. - -> #### Example -> -> ```python -> textcat = nlp.add_pipe("textcat_multilabel") -> scores = textcat.predict([doc1, doc2]) -> ``` - -| Name | Description | -| ----------- | ------------------------------------------- | -| `docs` | The documents to predict. ~~Iterable[Doc]~~ | -| **RETURNS** | The model's prediction for each document. | - -## MultiLabel_TextCategorizer.set_annotations {#set_annotations tag="method"} - -Modify a batch of [`Doc`](/api/doc) objects using pre-computed scores. - -> #### Example -> -> ```python -> textcat = nlp.add_pipe("textcat_multilabel") -> scores = textcat.predict(docs) -> textcat.set_annotations(docs, scores) -> ``` - -| Name | Description | -| -------- | --------------------------------------------------------- | -| `docs` | The documents to modify. ~~Iterable[Doc]~~ | -| `scores` | The scores to set, produced by `MultiLabel_TextCategorizer.predict`. | - -## MultiLabel_TextCategorizer.update {#update tag="method"} - -Learn from a batch of [`Example`](/api/example) objects containing the -predictions and gold-standard annotations, and update the component's model. 
-Delegates to [`predict`](/api/multilabel_textcategorizer#predict) and -[`get_loss`](/api/multilabel_textcategorizer#get_loss). - -> #### Example -> -> ```python -> textcat = nlp.add_pipe("textcat_multilabel") -> optimizer = nlp.initialize() -> losses = textcat.update(examples, sgd=optimizer) -> ``` - -| Name | Description | -| ----------------- | ---------------------------------------------------------------------------------------------------------------------------------- | -| `examples` | A batch of [`Example`](/api/example) objects to learn from. ~~Iterable[Example]~~ | -| _keyword-only_ | | -| `drop` | The dropout rate. ~~float~~ | -| `sgd` | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if not set. ~~Optional[Optimizer]~~ | -| `losses` | Optional record of the loss during training. Updated using the component name as the key. ~~Optional[Dict[str, float]]~~ | -| **RETURNS** | The updated `losses` dictionary. ~~Dict[str, float]~~ | - -## MultiLabel_TextCategorizer.rehearse {#rehearse tag="method,experimental" new="3"} - -Perform a "rehearsal" update from a batch of data. Rehearsal updates teach the -current model to make predictions similar to an initial model to try to address -the "catastrophic forgetting" problem. This feature is experimental. - -> #### Example -> -> ```python -> textcat = nlp.add_pipe("textcat_multilabel") -> optimizer = nlp.resume_training() -> losses = textcat.rehearse(examples, sgd=optimizer) -> ``` - -| Name | Description | -| -------------- | ------------------------------------------------------------------------------------------------------------------------ | -| `examples` | A batch of [`Example`](/api/example) objects to learn from. ~~Iterable[Example]~~ | -| _keyword-only_ | | -| `drop` | The dropout rate. ~~float~~ | -| `sgd` | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if not set. ~~Optional[Optimizer]~~ | -| `losses` | Optional record of the loss during training. Updated using the component name as the key. ~~Optional[Dict[str, float]]~~ | -| **RETURNS** | The updated `losses` dictionary. ~~Dict[str, float]~~ | - -## MultiLabel_TextCategorizer.get_loss {#get_loss tag="method"} - -Find the loss and gradient of loss for the batch of documents and their -predicted scores. - -> #### Example -> -> ```python -> textcat = nlp.add_pipe("textcat_multilabel") -> scores = textcat.predict([eg.predicted for eg in examples]) -> loss, d_loss = textcat.get_loss(examples, scores) -> ``` - -| Name | Description | -| ----------- | --------------------------------------------------------------------------- | -| `examples` | The batch of examples. ~~Iterable[Example]~~ | -| `scores` | Scores representing the model's predictions. | -| **RETURNS** | The loss and the gradient, i.e. `(loss, gradient)`. ~~Tuple[float, float]~~ | - -## MultiLabel_TextCategorizer.score {#score tag="method" new="3"} - -Score a batch of examples. - -> #### Example -> -> ```python -> scores = textcat.score(examples) -> ``` - -| Name | Description | -| ---------------- | -------------------------------------------------------------------------------------------------------------------- | -| `examples` | The examples to score. ~~Iterable[Example]~~ | -| _keyword-only_ | | -| **RETURNS** | The scores, produced by [`Scorer.score_cats`](/api/scorer#score_cats). ~~Dict[str, Union[float, Dict[str, float]]]~~ | - -## MultiLabel_TextCategorizer.create_optimizer {#create_optimizer tag="method"} - -Create an optimizer for the pipeline component. 
- -> #### Example -> -> ```python -> textcat = nlp.add_pipe("textcat") -> optimizer = textcat.create_optimizer() -> ``` - -| Name | Description | -| ----------- | ---------------------------- | -| **RETURNS** | The optimizer. ~~Optimizer~~ | - -## MultiLabel_TextCategorizer.use_params {#use_params tag="method, contextmanager"} - -Modify the pipe's model to use the given parameter values. - -> #### Example -> -> ```python -> textcat = nlp.add_pipe("textcat") -> with textcat.use_params(optimizer.averages): -> textcat.to_disk("/best_model") -> ``` - -| Name | Description | -| -------- | -------------------------------------------------- | -| `params` | The parameter values to use in the model. ~~dict~~ | - -## MultiLabel_TextCategorizer.add_label {#add_label tag="method"} - -Add a new label to the pipe. Raises an error if the output dimension is already -set, or if the model has already been fully [initialized](#initialize). Note -that you don't have to call this method if you provide a **representative data -sample** to the [`initialize`](#initialize) method. In this case, all labels -found in the sample will be automatically added to the model, and the output -dimension will be [inferred](/usage/layers-architectures#thinc-shape-inference) -automatically. - -> #### Example -> -> ```python -> textcat = nlp.add_pipe("textcat") -> textcat.add_label("MY_LABEL") -> ``` - -| Name | Description | -| ----------- | ----------------------------------------------------------- | -| `label` | The label to add. ~~str~~ | -| **RETURNS** | `0` if the label is already present, otherwise `1`. ~~int~~ | - -## MultiLabel_TextCategorizer.to_disk {#to_disk tag="method"} - -Serialize the pipe to disk. - -> #### Example -> -> ```python -> textcat = nlp.add_pipe("textcat") -> textcat.to_disk("/path/to/textcat") -> ``` - -| Name | Description | -| -------------- | ------------------------------------------------------------------------------------------------------------------------------------------ | -| `path` | A path to a directory, which will be created if it doesn't exist. Paths may be either strings or `Path`-like objects. ~~Union[str, Path]~~ | -| _keyword-only_ | | -| `exclude` | String names of [serialization fields](#serialization-fields) to exclude. ~~Iterable[str]~~ | - -## MultiLabel_TextCategorizer.from_disk {#from_disk tag="method"} - -Load the pipe from disk. Modifies the object in place and returns it. - -> #### Example -> -> ```python -> textcat = nlp.add_pipe("textcat") -> textcat.from_disk("/path/to/textcat") -> ``` - -| Name | Description | -| -------------- | ----------------------------------------------------------------------------------------------- | -| `path` | A path to a directory. Paths may be either strings or `Path`-like objects. ~~Union[str, Path]~~ | -| _keyword-only_ | | -| `exclude` | String names of [serialization fields](#serialization-fields) to exclude. ~~Iterable[str]~~ | -| **RETURNS** | The modified `MultiLabel_TextCategorizer` object. ~~MultiLabel_TextCategorizer~~ | - -## MultiLabel_TextCategorizer.to_bytes {#to_bytes tag="method"} - -> #### Example -> -> ```python -> textcat = nlp.add_pipe("textcat") -> textcat_bytes = textcat.to_bytes() -> ``` - -Serialize the pipe to a bytestring. - -| Name | Description | -| -------------- | ------------------------------------------------------------------------------------------- | -| _keyword-only_ | | -| `exclude` | String names of [serialization fields](#serialization-fields) to exclude. 
~~Iterable[str]~~ |
| **RETURNS**    | The serialized form of the `MultiLabel_TextCategorizer` object. ~~bytes~~                    |

## MultiLabel_TextCategorizer.from_bytes {#from_bytes tag="method"}

Load the pipe from a bytestring. Modifies the object in place and returns it.

> #### Example
>
> ```python
> textcat_bytes = textcat.to_bytes()
> textcat = nlp.add_pipe("textcat_multilabel")
> textcat.from_bytes(textcat_bytes)
> ```

| Name           | Description                                                                                   |
| -------------- | --------------------------------------------------------------------------------------------- |
| `bytes_data`   | The data to load from. ~~bytes~~                                                              |
| _keyword-only_ |                                                                                               |
| `exclude`      | String names of [serialization fields](#serialization-fields) to exclude. ~~Iterable[str]~~  |
| **RETURNS**    | The `MultiLabel_TextCategorizer` object. ~~MultiLabel_TextCategorizer~~                      |

## MultiLabel_TextCategorizer.labels {#labels tag="property"}

The labels currently added to the component.

> #### Example
>
> ```python
> textcat.add_label("MY_LABEL")
> assert "MY_LABEL" in textcat.labels
> ```

| Name        | Description                                             |
| ----------- | ------------------------------------------------------- |
| **RETURNS** | The labels added to the component. ~~Tuple[str, ...]~~  |

## MultiLabel_TextCategorizer.label_data {#label_data tag="property" new="3"}

The labels currently added to the component and their internal meta information.
This is the data generated by [`init labels`](/api/cli#init-labels) and used by
[`MultiLabel_TextCategorizer.initialize`](/api/multilabel_textcategorizer#initialize) to initialize
the model with a pre-defined label set.

> #### Example
>
> ```python
> labels = textcat.label_data
> textcat.initialize(lambda: [], nlp=nlp, labels=labels)
> ```

| Name        | Description                                                 |
| ----------- | ----------------------------------------------------------- |
| **RETURNS** | The label data added to the component. ~~Tuple[str, ...]~~  |

## Serialization fields {#serialization-fields}

During serialization, spaCy will export several data fields used to restore
different aspects of the object. If needed, you can exclude them from
serialization by passing in the string names via the `exclude` argument.

> #### Example
>
> ```python
> data = textcat.to_disk("/path", exclude=["vocab"])
> ```

| Name    | Description                                                     |
| ------- | --------------------------------------------------------------- |
| `vocab` | The shared [`Vocab`](/api/vocab).                               |
| `cfg`   | The config file. You usually don't want to exclude this.       |
| `model` | The binary model data. You usually don't want to exclude this. |
diff --git a/website/docs/api/textcategorizer.md b/website/docs/api/textcategorizer.md
index ac0ab4f27..fdd235b85 100644
--- a/website/docs/api/textcategorizer.md
+++ b/website/docs/api/textcategorizer.md
@@ -3,15 +3,30 @@ title: TextCategorizer
 tag: class
 source: spacy/pipeline/textcat.py
 new: 2
-teaser: 'Pipeline component for single-label text classification'
+teaser: 'Pipeline component for text classification'
 api_base_class: /api/pipe
 api_string_name: textcat
 api_trainable: true
 ---
 
-The text categorizer predicts **categories over a whole document**. It can learn
-one or more labels, and the labels are mutually exclusive - there is exactly one
-true label per document.
+The text categorizer predicts **categories over a whole document** and comes in
+two flavours: `textcat` and `textcat_multilabel`. When you need to predict
+exactly one true label per document, use the `textcat` component, which has
+mutually exclusive labels.
If you want to perform multi-label classification and predict +zero, one or more labels per document, use the `textcat_multilabel` component +instead. + +Both components are documented on this page. + + + +In spaCy v2, the `textcat` component could also perform **multi-label +classification**, and even used this setting by default. Since v3.0, the +component `textcat_multilabel` should be used for multi-label classification +instead. The `textcat` component is now used for mutually exclusive classes +only. + + ## Config and implementation {#config} @@ -22,7 +37,7 @@ how the component should be configured. You can override its settings via the [model architectures](/api/architectures) documentation for details on the architectures and their arguments and hyperparameters. -> #### Example +> #### Example (textcat) > > ```python > from spacy.pipeline.textcat import DEFAULT_SINGLE_TEXTCAT_MODEL @@ -33,6 +48,17 @@ architectures and their arguments and hyperparameters. > nlp.add_pipe("textcat", config=config) > ``` +> #### Example (textcat_multilabel) +> +> ```python +> from spacy.pipeline.textcat_multilabel import DEFAULT_MULTI_TEXTCAT_MODEL +> config = { +> "threshold": 0.5, +> "model": DEFAULT_MULTI_TEXTCAT_MODEL, +> } +> nlp.add_pipe("textcat_multilabel", config=config) +> ``` + | Setting | Description | | ----------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------- | | `threshold` | Cutoff to consider a prediction "positive", relevant when printing accuracy results. ~~float~~ | @@ -48,6 +74,7 @@ architectures and their arguments and hyperparameters. > > ```python > # Construction via add_pipe with default model +> # Use 'textcat_multilabel' for multi-label classification > textcat = nlp.add_pipe("textcat") > > # Construction via add_pipe with custom model @@ -55,6 +82,7 @@ architectures and their arguments and hyperparameters. > parser = nlp.add_pipe("textcat", config=config) > > # Construction from class +> # Use 'MultiLabel_TextCategorizer' for multi-label classification > from spacy.pipeline import TextCategorizer > textcat = TextCategorizer(nlp.vocab, model, threshold=0.5) > ``` @@ -161,7 +189,7 @@ This method was previously called `begin_training`. | _keyword-only_ | | | `nlp` | The current `nlp` object. Defaults to `None`. ~~Optional[Language]~~ | | `labels` | The label information to add to the component, as provided by the [`label_data`](#label_data) property after initialization. To generate a reusable JSON file from your data, you should run the [`init labels`](/api/cli#init-labels) command. If no labels are provided, the `get_examples` callback is used to extract the labels from the data, which may be a lot slower. ~~Optional[Iterable[str]]~~ | -| `positive_label` | The positive label for a binary task with exclusive classes, None otherwise and by default. ~~Optional[str]~~ | +| `positive_label` | The positive label for a binary task with exclusive classes, `None` otherwise and by default. This parameter is not available when using the `textcat_multilabel` component. 
~~Optional[str]~~ | ## TextCategorizer.predict {#predict tag="method"} @@ -212,14 +240,14 @@ Delegates to [`predict`](/api/textcategorizer#predict) and > losses = textcat.update(examples, sgd=optimizer) > ``` -| Name | Description | -| ----------------- | ---------------------------------------------------------------------------------------------------------------------------------- | -| `examples` | A batch of [`Example`](/api/example) objects to learn from. ~~Iterable[Example]~~ | -| _keyword-only_ | | -| `drop` | The dropout rate. ~~float~~ | -| `sgd` | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if not set. ~~Optional[Optimizer]~~ | -| `losses` | Optional record of the loss during training. Updated using the component name as the key. ~~Optional[Dict[str, float]]~~ | -| **RETURNS** | The updated `losses` dictionary. ~~Dict[str, float]~~ | +| Name | Description | +| -------------- | ------------------------------------------------------------------------------------------------------------------------ | +| `examples` | A batch of [`Example`](/api/example) objects to learn from. ~~Iterable[Example]~~ | +| _keyword-only_ | | +| `drop` | The dropout rate. ~~float~~ | +| `sgd` | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if not set. ~~Optional[Optimizer]~~ | +| `losses` | Optional record of the loss during training. Updated using the component name as the key. ~~Optional[Dict[str, float]]~~ | +| **RETURNS** | The updated `losses` dictionary. ~~Dict[str, float]~~ | ## TextCategorizer.rehearse {#rehearse tag="method,experimental" new="3"} @@ -273,11 +301,11 @@ Score a batch of examples. > scores = textcat.score(examples) > ``` -| Name | Description | -| ---------------- | -------------------------------------------------------------------------------------------------------------------- | -| `examples` | The examples to score. ~~Iterable[Example]~~ | -| _keyword-only_ | | -| **RETURNS** | The scores, produced by [`Scorer.score_cats`](/api/scorer#score_cats). ~~Dict[str, Union[float, Dict[str, float]]]~~ | +| Name | Description | +| -------------- | -------------------------------------------------------------------------------------------------------------------- | +| `examples` | The examples to score. ~~Iterable[Example]~~ | +| _keyword-only_ | | +| **RETURNS** | The scores, produced by [`Scorer.score_cats`](/api/scorer#score_cats). ~~Dict[str, Union[float, Dict[str, float]]]~~ | ## TextCategorizer.create_optimizer {#create_optimizer tag="method"} diff --git a/website/docs/usage/processing-pipelines.md b/website/docs/usage/processing-pipelines.md index 0058d40dc..909a9c7de 100644 --- a/website/docs/usage/processing-pipelines.md +++ b/website/docs/usage/processing-pipelines.md @@ -223,21 +223,22 @@ available pipeline components and component functions. > ruler = nlp.add_pipe("entity_ruler") > ``` -| String name | Component | Description | -| ----------------- | ----------------------------------------------- | ----------------------------------------------------------------------------------------- | -| `tagger` | [`Tagger`](/api/tagger) | Assign part-of-speech-tags. | -| `parser` | [`DependencyParser`](/api/dependencyparser) | Assign dependency labels. | -| `ner` | [`EntityRecognizer`](/api/entityrecognizer) | Assign named entities. | -| `entity_linker` | [`EntityLinker`](/api/entitylinker) | Assign knowledge base IDs to named entities. Should be added after the entity recognizer. 
| -| `entity_ruler` | [`EntityRuler`](/api/entityruler) | Assign named entities based on pattern rules and dictionaries. | -| `textcat` | [`TextCategorizer`](/api/textcategorizer) | Assign text categories. | -| `lemmatizer` | [`Lemmatizer`](/api/lemmatizer) | Assign base forms to words. | -| `morphologizer` | [`Morphologizer`](/api/morphologizer) | Assign morphological features and coarse-grained POS tags. | -| `attribute_ruler` | [`AttributeRuler`](/api/attributeruler) | Assign token attribute mappings and rule-based exceptions. | -| `senter` | [`SentenceRecognizer`](/api/sentencerecognizer) | Assign sentence boundaries. | -| `sentencizer` | [`Sentencizer`](/api/sentencizer) | Add rule-based sentence segmentation without the dependency parse. | -| `tok2vec` | [`Tok2Vec`](/api/tok2vec) | Assign token-to-vector embeddings. | -| `transformer` | [`Transformer`](/api/transformer) | Assign the tokens and outputs of a transformer model. | +| String name | Component | Description | +| -------------------- | ---------------------------------------------------- | ----------------------------------------------------------------------------------------- | +| `tagger` | [`Tagger`](/api/tagger) | Assign part-of-speech-tags. | +| `parser` | [`DependencyParser`](/api/dependencyparser) | Assign dependency labels. | +| `ner` | [`EntityRecognizer`](/api/entityrecognizer) | Assign named entities. | +| `entity_linker` | [`EntityLinker`](/api/entitylinker) | Assign knowledge base IDs to named entities. Should be added after the entity recognizer. | +| `entity_ruler` | [`EntityRuler`](/api/entityruler) | Assign named entities based on pattern rules and dictionaries. | +| `textcat` | [`TextCategorizer`](/api/textcategorizer) | Assign text categories: exactly one category is predicted per document. | +| `textcat_multilabel` | [`MultiLabel_TextCategorizer`](/api/textcategorizer) | Assign text categories in a multi-label setting: zero, one or more labels per document. | +| `lemmatizer` | [`Lemmatizer`](/api/lemmatizer) | Assign base forms to words. | +| `morphologizer` | [`Morphologizer`](/api/morphologizer) | Assign morphological features and coarse-grained POS tags. | +| `attribute_ruler` | [`AttributeRuler`](/api/attributeruler) | Assign token attribute mappings and rule-based exceptions. | +| `senter` | [`SentenceRecognizer`](/api/sentencerecognizer) | Assign sentence boundaries. | +| `sentencizer` | [`Sentencizer`](/api/sentencizer) | Add rule-based sentence segmentation without the dependency parse. | +| `tok2vec` | [`Tok2Vec`](/api/tok2vec) | Assign token-to-vector embeddings. | +| `transformer` | [`Transformer`](/api/transformer) | Assign the tokens and outputs of a transformer model. | ### Disabling, excluding and modifying components {#disabling} @@ -400,8 +401,8 @@ vectors available – otherwise, it won't be able to make the same predictions. > ``` > > By default, sourced components will be updated with your data during training. 
-> If you want to preserve the component as-is, you can "freeze" it if the pipeline -> is not using a shared `Tok2Vec` layer: +> If you want to preserve the component as-is, you can "freeze" it if the +> pipeline is not using a shared `Tok2Vec` layer: > > ```ini > [training] @@ -1244,7 +1245,7 @@ labels = [] # the argument "model" [components.textcat.model] @architectures = "spacy.TextCatBOW.v1" -exclusive_classes = false +exclusive_classes = true ngram_size = 1 no_output_layer = false diff --git a/website/docs/usage/v3.md b/website/docs/usage/v3.md index 5353f9ded..21e99ffc2 100644 --- a/website/docs/usage/v3.md +++ b/website/docs/usage/v3.md @@ -320,14 +320,15 @@ add to your pipeline and customize for your use case: > nlp.add_pipe("lemmatizer") > ``` -| Name | Description | -| ----------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| [`SentenceRecognizer`](/api/sentencerecognizer) | Trainable component for sentence segmentation. | -| [`Morphologizer`](/api/morphologizer) | Trainable component to predict morphological features. | -| [`Lemmatizer`](/api/lemmatizer) | Standalone component for rule-based and lookup lemmatization. | -| [`AttributeRuler`](/api/attributeruler) | Component for setting token attributes using match patterns. | -| [`Transformer`](/api/transformer) | Component for using [transformer models](/usage/embeddings-transformers) in your pipeline, accessing outputs and aligning tokens. Provided via [`spacy-transformers`](https://github.com/explosion/spacy-transformers). | -| [`TrainablePipe`](/api/pipe) | Base class for trainable pipeline components. | +| Name | Description | +| ----------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| [`SentenceRecognizer`](/api/sentencerecognizer) | Trainable component for sentence segmentation. | +| [`Morphologizer`](/api/morphologizer) | Trainable component to predict morphological features. | +| [`Lemmatizer`](/api/lemmatizer) | Standalone component for rule-based and lookup lemmatization. | +| [`AttributeRuler`](/api/attributeruler) | Component for setting token attributes using match patterns. | +| [`Transformer`](/api/transformer) | Component for using [transformer models](/usage/embeddings-transformers) in your pipeline, accessing outputs and aligning tokens. Provided via [`spacy-transformers`](https://github.com/explosion/spacy-transformers). | +| [`TrainablePipe`](/api/pipe) | Base class for trainable pipeline components. | +| [`Multi-label TextCategorizer`](/api/textcategorizer) | Trainable component for multi-label text classification. | @@ -592,6 +593,10 @@ Note that spaCy v3.0 now requires **Python 3.6+**. - Various keyword arguments across functions and methods are now explicitly declared as **keyword-only** arguments. Those arguments are documented accordingly across the API reference using the keyword-only tag. +- The `textcat` pipeline component is now only applicable for classification of + mutually exclusives classes - i.e. one predicted class per input sentence or + document. To perform multi-label classification, use the new + `textcat_multilabel` component instead. 
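The distinction drawn in the incompatibility note above is easiest to see in code. The following sketch is an editorial illustration appended after the patch hunk, not part of the patch itself: the `add_pipe` and `add_label` calls follow the documented v3 API, while the label names are invented for demonstration.

```python
import spacy

# Exclusive classes: "textcat" predicts exactly one label per document.
nlp_single = spacy.blank("en")
textcat = nlp_single.add_pipe("textcat")
for label in ("POSITIVE", "NEGATIVE"):  # hypothetical labels
    textcat.add_label(label)

# Multi-label: "textcat_multilabel" may predict zero, one or more labels.
nlp_multi = spacy.blank("en")
textcat_multi = nlp_multi.add_pipe("textcat_multilabel")
for label in ("SPORTS", "POLITICS", "TECH"):  # hypothetical labels
    textcat_multi.add_label(label)

# After training, both components write their scores to doc.cats,
# e.g. {"POSITIVE": 0.98, "NEGATIVE": 0.02} for the exclusive case.
```

In both cases the predictions end up in `doc.cats`; only the exclusivity assumption, and therefore the training objective and evaluation, differs between the two components.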
### Removed or renamed API {#incompat-removed} diff --git a/website/src/widgets/quickstart-training.js b/website/src/widgets/quickstart-training.js index 3d2ab0930..849c80f3d 100644 --- a/website/src/widgets/quickstart-training.js +++ b/website/src/widgets/quickstart-training.js @@ -9,6 +9,7 @@ import { htmlToReact } from '../components/util' const DEFAULT_LANG = 'en' const DEFAULT_HARDWARE = 'cpu' const DEFAULT_OPT = 'efficiency' +const DEFAULT_TEXTCAT_EXCLUSIVE = true const COMPONENTS = ['tagger', 'parser', 'ner', 'textcat'] const COMMENT = `# This is an auto-generated partial config. To use it with 'spacy train' # you can run spacy init fill-config to auto-fill all default settings: @@ -27,6 +28,19 @@ const DATA = [ options: COMPONENTS.map(id => ({ id, title: id })), multiple: true, }, + { + id: 'textcat', + title: 'Text Classification', + multiple: true, + options: [ + { + id: 'exclusive', + title: 'exclusive categories', + checked: DEFAULT_TEXTCAT_EXCLUSIVE, + help: 'only one label can apply', + }, + ], + }, { id: 'hardware', title: 'Hardware', @@ -49,14 +63,28 @@ const DATA = [ export default function QuickstartTraining({ id, title, download = 'base_config.cfg' }) { const [lang, setLang] = useState(DEFAULT_LANG) + const [_components, _setComponents] = useState([]) const [components, setComponents] = useState([]) const [[hardware], setHardware] = useState([DEFAULT_HARDWARE]) const [[optimize], setOptimize] = useState([DEFAULT_OPT]) + const [textcatExclusive, setTextcatExclusive] = useState(DEFAULT_TEXTCAT_EXCLUSIVE) + + function updateComponents(value, isExclusive) { + _setComponents(value) + const updated = value.map(c => (c === 'textcat' && !isExclusive ? 'textcat_multilabel' : c)) + setComponents(updated) + } + const setters = { lang: setLang, - components: setComponents, + components: v => updateComponents(v, textcatExclusive), hardware: setHardware, optimize: setOptimize, + textcat: v => { + const isExclusive = v.includes('exclusive') + setTextcatExclusive(isExclusive) + updateComponents(_components, isExclusive) + }, } const reco = GENERATOR_DATA[lang] || GENERATOR_DATA.__default__ const content = generator({ @@ -78,20 +106,24 @@ export default function QuickstartTraining({ id, title, download = 'base_config. 
{ + let data = DATA const langs = site.siteMetadata.languages - DATA[0].dropdown = langs + data[0].dropdown = langs .map(({ name, code }) => ({ id: code, title: name, })) .sort((a, b) => a.title.localeCompare(b.title)) + if (!_components.includes('textcat')) { + data = data.filter(({ id }) => id !== 'textcat') + } return ( Date: Wed, 10 Mar 2021 01:08:05 +1100 Subject: [PATCH 08/10] Update issue templates [ci skip] --- .../ISSUE_TEMPLATE/{03_docs.md => 02_docs.md} | 0 .github/ISSUE_TEMPLATE/02_install.md | 21 ------------------- .../{04_other.md => 03_other.md} | 0 3 files changed, 21 deletions(-) rename .github/ISSUE_TEMPLATE/{03_docs.md => 02_docs.md} (100%) delete mode 100644 .github/ISSUE_TEMPLATE/02_install.md rename .github/ISSUE_TEMPLATE/{04_other.md => 03_other.md} (100%) diff --git a/.github/ISSUE_TEMPLATE/03_docs.md b/.github/ISSUE_TEMPLATE/02_docs.md similarity index 100% rename from .github/ISSUE_TEMPLATE/03_docs.md rename to .github/ISSUE_TEMPLATE/02_docs.md diff --git a/.github/ISSUE_TEMPLATE/02_install.md b/.github/ISSUE_TEMPLATE/02_install.md deleted file mode 100644 index d0790bbdb..000000000 --- a/.github/ISSUE_TEMPLATE/02_install.md +++ /dev/null @@ -1,21 +0,0 @@ ---- -name: "\U000023F3 Installation Problem" -about: Do you have problems installing spaCy, and none of the suggestions in the docs - and other issues helped? - ---- - - -## How to reproduce the problem - - -```bash -# copy-paste the error message here -``` - -## Your Environment - -* Operating System: -* Python Version Used: -* spaCy Version Used: -* Environment Information: diff --git a/.github/ISSUE_TEMPLATE/04_other.md b/.github/ISSUE_TEMPLATE/03_other.md similarity index 100% rename from .github/ISSUE_TEMPLATE/04_other.md rename to .github/ISSUE_TEMPLATE/03_other.md From d746ea6278b3419986c1e6a8359b236a47ab7abc Mon Sep 17 00:00:00 2001 From: Adriane Boyd Date: Tue, 9 Mar 2021 15:35:21 +0100 Subject: [PATCH 09/10] Add warning about GPU selection in Jupyter notebooks (#7075) * Initial warning * Update check * Redo edit * Move jupyter warning to helper method * Add link with details to warnings --- spacy/errors.py | 5 +++++ spacy/language.py | 5 +++++ spacy/util.py | 12 ++++++++++++ website/docs/api/top-level.md | 24 ++++++++++++++++++++++++ website/docs/usage/v3.md | 12 ++++++++++++ 5 files changed, 58 insertions(+) diff --git a/spacy/errors.py b/spacy/errors.py index e50a658d8..4f9e90b57 100644 --- a/spacy/errors.py +++ b/spacy/errors.py @@ -147,6 +147,11 @@ class Warnings: "will be included in the results. For better results, token " "patterns should return matches that are each exactly one token " "long.") + W111 = ("Jupyter notebook detected: if using `prefer_gpu()` or " + "`require_gpu()`, include it in the same cell right before " + "`spacy.load()` to ensure that the model is loaded on the correct " + "device. 
More information: " + "http://spacy.io/usage/v3#jupyter-notebook-gpu") @add_codes diff --git a/spacy/language.py b/spacy/language.py index 5741ef97c..871dfafaa 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -22,6 +22,7 @@ from .training.initialize import init_vocab, init_tok2vec from .scorer import Scorer from .util import registry, SimpleFrozenList, _pipe, raise_error from .util import SimpleFrozenDict, combine_score_weights, CONFIG_SECTION_ORDER +from .util import warn_if_jupyter_cupy from .lang.tokenizer_exceptions import URL_MATCH, BASE_EXCEPTIONS from .lang.punctuation import TOKENIZER_PREFIXES, TOKENIZER_SUFFIXES from .lang.punctuation import TOKENIZER_INFIXES @@ -1622,6 +1623,10 @@ class Language: or lang_cls is not cls ): raise ValueError(Errors.E943.format(value=type(lang_cls))) + + # Warn about require_gpu usage in jupyter notebook + warn_if_jupyter_cupy() + # Note that we don't load vectors here, instead they get loaded explicitly # inside stuff like the spacy train function. If we loaded them here, # then we would load them twice at runtime: once when we make from config, diff --git a/spacy/util.py b/spacy/util.py index bcb51fe7d..4b82eea8d 100644 --- a/spacy/util.py +++ b/spacy/util.py @@ -1500,3 +1500,15 @@ def raise_error(proc_name, proc, docs, e): def ignore_error(proc_name, proc, docs, e): pass + + +def warn_if_jupyter_cupy(): + """Warn about require_gpu if a jupyter notebook + cupy + mismatched + contextvars vs. thread ops are detected + """ + if is_in_jupyter(): + from thinc.backends.cupy_ops import CupyOps + if CupyOps.xp is not None: + from thinc.backends import contextvars_eq_thread_ops + if not contextvars_eq_thread_ops(): + warnings.warn(Warnings.W111) diff --git a/website/docs/api/top-level.md b/website/docs/api/top-level.md index 37f619f3e..e1d81a5b5 100644 --- a/website/docs/api/top-level.md +++ b/website/docs/api/top-level.md @@ -138,6 +138,14 @@ data has already been allocated on CPU, it will not be moved. Ideally, this function should be called right after importing spaCy and _before_ loading any pipelines. + + +In a Jupyter notebook, run `prefer_gpu()` in the same cell as `spacy.load()` +to ensure that the model is loaded on the correct device. See [more +details](/usage/v3#jupyter-notebook-gpu). + + + > #### Example > > ```python @@ -158,6 +166,14 @@ if no GPU is available. If data has already been allocated on CPU, it will not be moved. Ideally, this function should be called right after importing spaCy and _before_ loading any pipelines. + + +In a Jupyter notebook, run `require_gpu()` in the same cell as `spacy.load()` +to ensure that the model is loaded on the correct device. See [more +details](/usage/v3#jupyter-notebook-gpu). + + + > #### Example > > ```python @@ -177,6 +193,14 @@ Allocate data and perform operations on CPU. If data has already been allocated on GPU, it will not be moved. Ideally, this function should be called right after importing spaCy and _before_ loading any pipelines. + + +In a Jupyter notebook, run `require_cpu()` in the same cell as `spacy.load()` +to ensure that the model is loaded on the correct device. See [more +details](/usage/v3#jupyter-notebook-gpu). + + + > #### Example > > ```python diff --git a/website/docs/usage/v3.md b/website/docs/usage/v3.md index 21e99ffc2..847d4a327 100644 --- a/website/docs/usage/v3.md +++ b/website/docs/usage/v3.md @@ -1179,3 +1179,15 @@ This means that spaCy knows how to initialize `my_component`, even if your package isn't imported. 
+ +#### Using GPUs in Jupyter notebooks {#jupyter-notebook-gpu} + +In Jupyter notebooks, run [`prefer_gpu`](/api/top-level#spacy.prefer_gpu), +[`require_gpu`](/api/top-level#spacy.require_gpu) or +[`require_cpu`](/api/top-level#spacy.require_cpu) in the same cell as +[`spacy.load`](/api/top-level#spacy.load) to ensure that the model is loaded on the correct device. + +Due to a bug related to `contextvars` (see the [bug +report](https://github.com/ipython/ipython/issues/11565)), the GPU settings may +not be preserved correctly across cells, resulting in models being loaded on +the wrong device or only partially on GPU. From 3b911ee5ef2240919b66a0ce55a5d387ceb6f904 Mon Sep 17 00:00:00 2001 From: Adriane Boyd Date: Tue, 9 Mar 2021 16:49:41 +0100 Subject: [PATCH 10/10] Set version to v3.0.4 (#7376) --- spacy/about.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spacy/about.py b/spacy/about.py index c19e1aeaa..4cbfdbad3 100644 --- a/spacy/about.py +++ b/spacy/about.py @@ -1,6 +1,6 @@ # fmt: off __title__ = "spacy" -__version__ = "3.0.3" +__version__ = "3.0.4" __download_url__ = "https://github.com/explosion/spacy-models/releases/download" __compatibility__ = "https://raw.githubusercontent.com/explosion/spacy-models/master/compatibility.json" __projects__ = "https://github.com/explosion/projects"
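As a closing illustration of the Jupyter notebook guidance added in the GPU-warning patch above, the pattern recommended by the new W111 warning and the `#jupyter-notebook-gpu` docs section looks roughly like the sketch below. It is illustrative only: the pipeline name `en_core_web_sm` is a placeholder, and `require_gpu()` assumes CuPy and a GPU are actually available.

```python
# One notebook cell: keep the device call and the model load together,
# so the contextvars/threading mismatch described above cannot separate them.
import spacy

spacy.prefer_gpu()  # or spacy.require_gpu() to fail loudly when no GPU is present
nlp = spacy.load("en_core_web_sm")  # placeholder pipeline name

doc = nlp("Loading right after the GPU call keeps the whole pipeline on one device.")
```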