diff --git a/spacy/language.py b/spacy/language.py index d2b89029d..fb86689bc 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -1248,17 +1248,12 @@ class Language: component_cfg[name].setdefault("drop", drop) pipe_kwargs[name].setdefault("batch_size", self.batch_size) for name, proc in self.pipeline: - # ignore statements are used here because mypy ignores hasattr - if name not in exclude and hasattr(proc, "update"): - proc.update(examples, sgd=None, losses=losses, **component_cfg[name]) # type: ignore - if sgd not in (None, False): - if ( - name not in exclude - and isinstance(proc, ty.TrainableComponent) - and proc.is_trainable - and proc.model not in (True, False, None) - ): - proc.finish_update(sgd) + if ( + name not in exclude + and isinstance(proc, ty.TrainableComponent) + and proc.is_trainable + ): + proc.update(examples, sgd=None, losses=losses, **component_cfg[name]) if name in annotates: for doc, eg in zip( _pipe( @@ -1271,6 +1266,17 @@ class Language: examples, ): eg.predicted = doc + # Only finish the update after all component updates are done. Some + # components may share weights (such as tok2vec) and we only want + # to apply weight updates after all gradients are accumulated. + for name, proc in self.pipeline: + if ( + name not in exclude + and isinstance(proc, ty.TrainableComponent) + and proc.is_trainable + ): + proc.finish_update(sgd) + return losses def rehearse( diff --git a/spacy/tests/pipeline/test_annotates_on_update.py b/spacy/tests/pipeline/test_annotates_on_update.py index 869b8b874..10fb22c97 100644 --- a/spacy/tests/pipeline/test_annotates_on_update.py +++ b/spacy/tests/pipeline/test_annotates_on_update.py @@ -54,9 +54,11 @@ def test_annotates_on_update(): return AssertSents(name) class AssertSents: + model = None + is_trainable = True + def __init__(self, name, **cfg): self.name = name - pass def __call__(self, doc): if not doc.has_annotation("SENT_START"): @@ -64,10 +66,16 @@ def test_annotates_on_update(): return doc def update(self, examples, *, drop=0.0, sgd=None, losses=None): + losses.setdefault(self.name, 0.0) + for example in examples: if not example.predicted.has_annotation("SENT_START"): raise ValueError("No sents") - return {} + + return losses + + def finish_update(self, sgd=None): + pass nlp = English() nlp.add_pipe("sentencizer") diff --git a/spacy/tests/test_language.py b/spacy/tests/test_language.py index f2d6d5fc0..3d0905dd3 100644 --- a/spacy/tests/test_language.py +++ b/spacy/tests/test_language.py @@ -10,8 +10,9 @@ from spacy.training import Example from spacy.lang.en import English from spacy.lang.de import German from spacy.util import registry, ignore_error, raise_error, find_matching_language +from spacy.util import load_model_from_config import spacy -from thinc.api import CupyOps, NumpyOps, get_current_ops +from thinc.api import Config, CupyOps, NumpyOps, get_array_module, get_current_ops from .util import add_vecs_to_vocab, assert_docs_equal @@ -25,6 +26,51 @@ try: except ImportError: pass +TAGGER_CFG_STRING = """ + [nlp] + lang = "en" + pipeline = ["tok2vec","tagger"] + + [components] + + [components.tagger] + factory = "tagger" + + [components.tagger.model] + @architectures = "spacy.Tagger.v2" + nO = null + + [components.tagger.model.tok2vec] + @architectures = "spacy.Tok2VecListener.v1" + width = ${components.tok2vec.model.encode.width} + + [components.tok2vec] + factory = "tok2vec" + + [components.tok2vec.model] + @architectures = "spacy.Tok2Vec.v2" + + [components.tok2vec.model.embed] + @architectures = "spacy.MultiHashEmbed.v1" + width = ${components.tok2vec.model.encode.width} + rows = [2000, 1000, 1000, 1000] + attrs = ["NORM", "PREFIX", "SUFFIX", "SHAPE"] + include_static_vectors = false + + [components.tok2vec.model.encode] + @architectures = "spacy.MaxoutWindowEncoder.v2" + width = 96 + depth = 4 + window_size = 1 + maxout_pieces = 3 + """ + + +TAGGER_TRAIN_DATA = [ + ("I like green eggs", {"tags": ["N", "V", "J", "N"]}), + ("Eat blue ham", {"tags": ["V", "J", "N"]}), +] + TAGGER_TRAIN_DATA = [ ("I like green eggs", {"tags": ["N", "V", "J", "N"]}), @@ -91,6 +137,26 @@ def test_language_update(nlp): example = Example.from_dict(doc, wrongkeyannots) +def test_language_update_updates(): + config = Config().from_str(TAGGER_CFG_STRING) + nlp = load_model_from_config(config, auto_fill=True, validate=True) + + train_examples = [] + for t in TAGGER_TRAIN_DATA: + train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1])) + + optimizer = nlp.initialize(get_examples=lambda: train_examples) + + docs_before_update = list(nlp.pipe([eg.predicted.copy() for eg in train_examples])) + nlp.update(train_examples, sgd=optimizer) + docs_after_update = list(nlp.pipe([eg.predicted.copy() for eg in train_examples])) + + xp = get_array_module(docs_after_update[0].tensor) + assert xp.any( + xp.not_equal(docs_before_update[0].tensor, docs_after_update[0].tensor) + ) + + def test_language_evaluate(nlp): text = "hello world" annots = {"doc_annotation": {"cats": {"POSITIVE": 1.0, "NEGATIVE": 0.0}}} diff --git a/spacy/training/loop.py b/spacy/training/loop.py index fc929816d..fcc023a0d 100644 --- a/spacy/training/loop.py +++ b/spacy/training/loop.py @@ -210,7 +210,7 @@ def train_while_improving( subbatch, drop=dropout, losses=losses, - sgd=False, # type: ignore[arg-type] + sgd=None, exclude=exclude, annotates=annotating_components, )