From dd207a28be805a1900af6b724a43f36f71f4812e Mon Sep 17 00:00:00 2001
From: Sofie Van Landeghem
Date: Thu, 9 Jul 2020 19:43:39 +0200
Subject: [PATCH] cleanup components API (#5726)

* add keyword separator for update functions and drop unused "state"

* few more Example tests and various small fixes

* consistently return losses after update call

* eliminate unused tensors field across pipe components

* fix name

* fix arg name
---
 spacy/errors.py | 3 +
 spacy/gold/example.pyx | 4 +-
 spacy/language.py | 23 ++--
 spacy/pipeline/pipes.pyx | 128 +++++++-----------
 spacy/pipeline/simple_ner.py | 15 +-
 spacy/pipeline/tok2vec.py | 11 +-
 spacy/syntax/nn_parser.pyx | 8 +-
 spacy/tests/regression/test_issue4001-4500.py | 2 +-
 spacy/tests/test_gold.py | 22 +++
 spacy/tokens/doc.pyx | 2 +-
 10 files changed, 109 insertions(+), 109 deletions(-)

diff --git a/spacy/errors.py b/spacy/errors.py
index 5a4e0d0c7..fa432382d 100644
--- a/spacy/errors.py
+++ b/spacy/errors.py
@@ -69,6 +69,9 @@ class Warnings(object):
W027 = ("Found a large training file of {size} bytes. Note that it may "
"be more efficient to split your training data into multiple "
"smaller JSON files instead.")
+ W028 = ("Doc.from_array was called with a vector of type '{type}', "
+ "but is expecting one of type 'uint64' instead. This may result "
+ "in problems with the vocab further on in the pipeline.")
W030 = ("Some entities could not be aligned in the text \"{text}\" with "
"entities \"{entities}\". Use "
"`spacy.gold.biluo_tags_from_offsets(nlp.make_doc(text), entities)`"
diff --git a/spacy/gold/example.pyx b/spacy/gold/example.pyx
index 09bc95bff..355578de3 100644
--- a/spacy/gold/example.pyx
+++ b/spacy/gold/example.pyx
@@ -329,8 +329,8 @@ def _fix_legacy_dict_data(example_dict):
for key, value in old_token_dict.items():
if key in ("text", "ids", "brackets"):
pass
- elif key in remapping:
- token_dict[remapping[key]] = value
+ elif key.lower() in remapping:
+ token_dict[remapping[key.lower()]] = value
else:
raise KeyError(Errors.E983.format(key=key, dict="token_annotation", keys=remapping.keys()))
text = example_dict.get("text", example_dict.get("raw"))
diff --git a/spacy/language.py b/spacy/language.py
index a95b6d279..32c8512fc 100644
--- a/spacy/language.py
+++ b/spacy/language.py
@@ -513,20 +513,23 @@ class Language(object):
):
"""Update the models in the pipeline.
- examples (iterable): A batch of `Example` objects.
+ examples (Iterable[Example]): A batch of examples.
dummy: Should not be set - serves to catch backwards-incompatible scripts.
drop (float): The dropout rate.
- sgd (callable): An optimizer.
- losses (dict): Dictionary to update with the loss, keyed by component.
- component_cfg (dict): Config parameters for specific pipeline
+ sgd (Optimizer): An optimizer.
+ losses (Dict[str, float]): Dictionary to update with the loss, keyed by component.
+ component_cfg (Dict[str, Dict]): Config parameters for specific pipeline
components, keyed by component name.
+ RETURNS (Dict[str, float]): The updated losses dictionary.
DOCS: https://spacy.io/api/language#update
"""
if dummy is not None:
raise ValueError(Errors.E989)
+ if losses is None:
+ losses = {}
if len(examples) == 0:
- return
+ return losses
if not isinstance(examples, Iterable):
raise TypeError(Errors.E978.format(name="language", method="update", types=type(examples)))
wrong_types = set([type(eg) for eg in examples if not isinstance(eg, Example)])
@@ -552,6 +555,7 @@ class Language(object):
for name, proc in self.pipeline:
if hasattr(proc, "model"):
proc.model.finish_update(sgd)
+ return losses
def rehearse(self, examples, sgd=None, losses=None, config=None):
"""Make a "rehearsal" update to the models in the pipeline, to prevent
@@ -757,18 +761,17 @@ class Language(object):
):
"""Process texts as a stream, and yield `Doc` objects in order.
- texts (iterator): A sequence of texts to process.
+ texts (Iterable[str]): A sequence of texts to process.
as_tuples (bool): If set to True, inputs should be a sequence of
(text, context) tuples. Output will then be a sequence of
(doc, context) tuples. Defaults to False.
batch_size (int): The number of texts to buffer.
- disable (list): Names of the pipeline components to disable.
+ disable (List[str]): Names of the pipeline components to disable.
cleanup (bool): If True, unneeded strings are freed to control memory
use. Experimental.
- component_cfg (dict): An optional dictionary with extra keyword
+ component_cfg (Dict[str, Dict]): An optional dictionary with extra keyword
arguments for specific components.
- n_process (int): Number of processors to process texts, only supported
- in Python3. If -1, set `multiprocessing.cpu_count()`.
+ n_process (int): Number of processors to process texts. If -1, `multiprocessing.cpu_count()` is used.
YIELDS (Doc): Documents in the order of the original text.
DOCS: https://spacy.io/api/language#pipe
diff --git a/spacy/pipeline/pipes.pyx b/spacy/pipeline/pipes.pyx
index 86c768e9b..c35cb4b68 100644
--- a/spacy/pipeline/pipes.pyx
+++ b/spacy/pipeline/pipes.pyx
@@ -58,12 +58,8 @@ class Pipe(object):
Both __call__ and pipe should delegate to the `predict()`
and `set_annotations()` methods.
"""
- predictions = self.predict([doc])
- if isinstance(predictions, tuple) and len(predictions) == 2:
- scores, tensors = predictions
- self.set_annotations([doc], scores, tensors=tensors)
- else:
- self.set_annotations([doc], predictions)
+ scores = self.predict([doc])
+ self.set_annotations([doc], scores)
return doc
def pipe(self, stream, batch_size=128):
@@ -73,12 +69,8 @@ class Pipe(object):
and `set_annotations()` methods.
"""
for docs in util.minibatch(stream, size=batch_size):
- predictions = self.predict(docs)
- if isinstance(predictions, tuple) and len(tuple) == 2:
- scores, tensors = predictions
- self.set_annotations(docs, scores, tensors=tensors)
- else:
- self.set_annotations(docs, predictions)
+ scores = self.predict(docs)
+ self.set_annotations(docs, scores)
yield from docs
def predict(self, docs):
@@ -87,7 +79,7 @@ class Pipe(object):
"""
raise NotImplementedError
- def set_annotations(self, docs, scores, tensors=None):
+ def set_annotations(self, docs, scores):
"""Modify a batch of documents, using pre-computed scores."""
raise NotImplementedError
@@ -281,9 +273,10 @@ class Tagger(Pipe):
idx += 1
doc.is_tagged = True
- def update(self, examples, drop=0., sgd=None, losses=None, set_annotations=False):
- if losses is not None and self.name not in losses:
- losses[self.name] = 0.
+ def update(self, examples, *, drop=0., sgd=None, losses=None, set_annotations=False): + if losses is None: + losses = {} + losses.setdefault(self.name, 0.0) try: if not any(len(eg.predicted) if eg.predicted else 0 for eg in examples): @@ -303,11 +296,11 @@ class Tagger(Pipe): if sgd not in (None, False): self.model.finish_update(sgd) - if losses is not None: - losses[self.name] += loss + losses[self.name] += loss if set_annotations: docs = [eg.predicted for eg in examples] self.set_annotations(docs, self._scores2guesses(tag_scores)) + return losses def rehearse(self, examples, drop=0., sgd=None, losses=None): """Perform a 'rehearsal' update, where we try to match the output of @@ -635,7 +628,7 @@ class MultitaskObjective(Tagger): def labels(self, value): self.cfg["labels"] = value - def set_annotations(self, docs, dep_ids, tensors=None): + def set_annotations(self, docs, dep_ids): pass def begin_training(self, get_examples=lambda: [], pipeline=None, @@ -732,7 +725,7 @@ class ClozeMultitask(Pipe): self.cfg = cfg self.distance = CosineDistance(ignore_zeros=True, normalize=False) # TODO: in config - def set_annotations(self, docs, dep_ids, tensors=None): + def set_annotations(self, docs, dep_ids): pass def begin_training(self, get_examples=lambda: [], pipeline=None, @@ -761,7 +754,7 @@ class ClozeMultitask(Pipe): loss = self.distance.get_loss(prediction, target) return loss, gradient - def update(self, examples, drop=0., set_annotations=False, sgd=None, losses=None): + def update(self, examples, *, drop=0., set_annotations=False, sgd=None, losses=None): pass def rehearse(self, examples, drop=0., sgd=None, losses=None): @@ -809,8 +802,8 @@ class TextCategorizer(Pipe): def pipe(self, stream, batch_size=128): for docs in util.minibatch(stream, size=batch_size): - scores, tensors = self.predict(docs) - self.set_annotations(docs, scores, tensors=tensors) + scores = self.predict(docs) + self.set_annotations(docs, scores) yield from docs def predict(self, docs): @@ -820,22 +813,25 @@ class TextCategorizer(Pipe): # Handle cases where there are no tokens in any docs. xp = get_array_module(tensors) scores = xp.zeros((len(docs), len(self.labels))) - return scores, tensors + return scores scores = self.model.predict(docs) scores = self.model.ops.asarray(scores) - return scores, tensors + return scores - def set_annotations(self, docs, scores, tensors=None): + def set_annotations(self, docs, scores): for i, doc in enumerate(docs): for j, label in enumerate(self.labels): doc.cats[label] = float(scores[i, j]) - def update(self, examples, state=None, drop=0., set_annotations=False, sgd=None, losses=None): + def update(self, examples, *, drop=0., set_annotations=False, sgd=None, losses=None): + if losses is None: + losses = {} + losses.setdefault(self.name, 0.0) try: if not any(len(eg.predicted) if eg.predicted else 0 for eg in examples): # Handle cases where there are no tokens in any docs. 
- return + return losses except AttributeError: types = set([type(eg) for eg in examples]) raise TypeError(Errors.E978.format(name="TextCategorizer", method="update", types=types)) @@ -847,12 +843,11 @@ class TextCategorizer(Pipe): bp_scores(d_scores) if sgd is not None: self.model.finish_update(sgd) - if losses is not None: - losses.setdefault(self.name, 0.0) - losses[self.name] += loss + losses[self.name] += loss if set_annotations: docs = [eg.predicted for eg in examples] self.set_annotations(docs, scores=scores) + return losses def rehearse(self, examples, drop=0., sgd=None, losses=None): if self._rehearsal_model is None: @@ -1076,12 +1071,13 @@ class EntityLinker(Pipe): sgd = self.create_optimizer() return sgd - def update(self, examples, state=None, set_annotations=False, drop=0.0, sgd=None, losses=None): + def update(self, examples, *, set_annotations=False, drop=0.0, sgd=None, losses=None): self.require_kb() - if losses is not None: - losses.setdefault(self.name, 0.0) + if losses is None: + losses = {} + losses.setdefault(self.name, 0.0) if not examples: - return 0 + return losses sentence_docs = [] try: docs = [eg.predicted for eg in examples] @@ -1124,20 +1120,19 @@ class EntityLinker(Pipe): return 0.0 sentence_encodings, bp_context = self.model.begin_update(sentence_docs) loss, d_scores = self.get_similarity_loss( - scores=sentence_encodings, + sentence_encodings=sentence_encodings, examples=examples ) bp_context(d_scores) if sgd is not None: self.model.finish_update(sgd) - if losses is not None: - losses[self.name] += loss + losses[self.name] += loss if set_annotations: self.set_annotations(docs, predictions) - return loss + return losses - def get_similarity_loss(self, examples, scores): + def get_similarity_loss(self, examples, sentence_encodings): entity_encodings = [] for eg in examples: kb_ids = eg.get_aligned("ENT_KB_ID", as_string=True) @@ -1149,41 +1144,23 @@ class EntityLinker(Pipe): entity_encodings = self.model.ops.asarray(entity_encodings, dtype="float32") - if scores.shape != entity_encodings.shape: + if sentence_encodings.shape != entity_encodings.shape: raise RuntimeError(Errors.E147.format(method="get_similarity_loss", msg="gold entities do not match up")) - gradients = self.distance.get_grad(scores, entity_encodings) - loss = self.distance.get_loss(scores, entity_encodings) + gradients = self.distance.get_grad(sentence_encodings, entity_encodings) + loss = self.distance.get_loss(sentence_encodings, entity_encodings) loss = loss / len(entity_encodings) return loss, gradients - def get_loss(self, examples, scores): - cats = [] - for eg in examples: - kb_ids = eg.get_aligned("ENT_KB_ID", as_string=True) - for ent in eg.predicted.ents: - kb_id = kb_ids[ent.start] - if kb_id: - cats.append([1.0]) - - cats = self.model.ops.asarray(cats, dtype="float32") - if len(scores) != len(cats): - raise RuntimeError(Errors.E147.format(method="get_loss", msg="gold entities do not match up")) - - d_scores = (scores - cats) - loss = (d_scores ** 2).sum() - loss = loss / len(cats) - return loss, d_scores - def __call__(self, doc): - kb_ids, tensors = self.predict([doc]) - self.set_annotations([doc], kb_ids, tensors=tensors) + kb_ids = self.predict([doc]) + self.set_annotations([doc], kb_ids) return doc def pipe(self, stream, batch_size=128): for docs in util.minibatch(stream, size=batch_size): - kb_ids, tensors = self.predict(docs) - self.set_annotations(docs, kb_ids, tensors=tensors) + kb_ids = self.predict(docs) + self.set_annotations(docs, kb_ids) yield from docs def 
predict(self, docs): @@ -1191,10 +1168,9 @@ class EntityLinker(Pipe): self.require_kb() entity_count = 0 final_kb_ids = [] - final_tensors = [] if not docs: - return final_kb_ids, final_tensors + return final_kb_ids if isinstance(docs, Doc): docs = [docs] @@ -1228,21 +1204,18 @@ class EntityLinker(Pipe): if to_discard and ent.label_ in to_discard: # ignoring this entity - setting to NIL final_kb_ids.append(self.NIL) - final_tensors.append(sentence_encoding) else: candidates = self.kb.get_candidates(ent.text) if not candidates: # no prediction possible for this entity - setting to NIL final_kb_ids.append(self.NIL) - final_tensors.append(sentence_encoding) elif len(candidates) == 1: # shortcut for efficiency reasons: take the 1 candidate # TODO: thresholding final_kb_ids.append(candidates[0].entity_) - final_tensors.append(sentence_encoding) else: random.shuffle(candidates) @@ -1271,14 +1244,13 @@ class EntityLinker(Pipe): best_index = scores.argmax().item() best_candidate = candidates[best_index] final_kb_ids.append(best_candidate.entity_) - final_tensors.append(sentence_encoding) - if not (len(final_tensors) == len(final_kb_ids) == entity_count): + if not (len(final_kb_ids) == entity_count): raise RuntimeError(Errors.E147.format(method="predict", msg="result variables not of equal length")) - return final_kb_ids, final_tensors + return final_kb_ids - def set_annotations(self, docs, kb_ids, tensors=None): + def set_annotations(self, docs, kb_ids): count_ents = len([ent for doc in docs for ent in doc.ents]) if count_ents != len(kb_ids): raise ValueError(Errors.E148.format(ents=count_ents, ids=len(kb_ids))) @@ -1394,11 +1366,7 @@ class Sentencizer(Pipe): def pipe(self, stream, batch_size=128): for docs in util.minibatch(stream, size=batch_size): predictions = self.predict(docs) - if isinstance(predictions, tuple) and len(tuple) == 2: - scores, tensors = predictions - self.set_annotations(docs, scores, tensors=tensors) - else: - self.set_annotations(docs, predictions) + self.set_annotations(docs, predictions) yield from docs def predict(self, docs): @@ -1429,7 +1397,7 @@ class Sentencizer(Pipe): guesses.append(doc_guesses) return guesses - def set_annotations(self, docs, batch_tag_ids, tensors=None): + def set_annotations(self, docs, batch_tag_ids): if isinstance(docs, Doc): docs = [docs] cdef Doc doc diff --git a/spacy/pipeline/simple_ner.py b/spacy/pipeline/simple_ner.py index e4a1e15e9..bf5783b1a 100644 --- a/spacy/pipeline/simple_ner.py +++ b/spacy/pipeline/simple_ner.py @@ -57,7 +57,7 @@ class SimpleNER(Pipe): scores = self.model.predict(docs) return scores - def set_annotations(self, docs: List[Doc], scores: List[Floats2d], tensors=None): + def set_annotations(self, docs: List[Doc], scores: List[Floats2d]): """Set entities on a batch of documents from a batch of scores.""" tag_names = self.get_tag_names() for i, doc in enumerate(docs): @@ -67,9 +67,12 @@ class SimpleNER(Pipe): tags = iob_to_biluo(tags) doc.ents = spans_from_biluo_tags(doc, tags) - def update(self, examples, set_annotations=False, drop=0.0, sgd=None, losses=None): + def update(self, examples, *, set_annotations=False, drop=0.0, sgd=None, losses=None): + if losses is None: + losses = {} + losses.setdefault("ner", 0.0) if not any(_has_ner(eg) for eg in examples): - return 0 + return losses docs = [eg.predicted for eg in examples] set_dropout_rate(self.model, drop) scores, bp_scores = self.model.begin_update(docs) @@ -79,10 +82,8 @@ class SimpleNER(Pipe): self.set_annotations(docs, scores) if sgd is not None: 
self.model.finish_update(sgd)
- if losses is not None:
- losses.setdefault("ner", 0.0)
- losses["ner"] += loss
- return loss
+ losses["ner"] += loss
+ return losses
def get_loss(self, examples, scores):
loss = 0
diff --git a/spacy/pipeline/tok2vec.py b/spacy/pipeline/tok2vec.py
index a06513a73..56afb3925 100644
--- a/spacy/pipeline/tok2vec.py
+++ b/spacy/pipeline/tok2vec.py
@@ -83,12 +83,14 @@ class Tok2Vec(Pipe):
assert tokvecs.shape[0] == len(doc)
doc.tensor = tokvecs
- def update(self, examples, drop=0.0, sgd=None, losses=None, set_annotations=False):
+ def update(self, examples, *, drop=0.0, sgd=None, losses=None, set_annotations=False):
"""Update the model.
- examples (iterable): A batch of examples
+ examples (Iterable[Example]): A batch of examples.
drop (float): The dropout rate.
- sgd (callable): An optimizer.
- RETURNS (dict): Results from the update.
+ sgd (Optimizer): An optimizer.
+ losses (Dict[str, float]): Dictionary to update with the loss, keyed by component.
+ set_annotations (bool): Whether or not to update the examples with the predictions.
+ RETURNS (Dict[str, float]): The updated losses dictionary.
"""
if losses is None:
losses = {}
@@ -124,6 +126,7 @@ class Tok2Vec(Pipe):
self.listeners[-1].receive(batch_id, tokvecs, backprop)
if set_annotations:
self.set_annotations(docs, tokvecs)
+ return losses
def get_loss(self, docs, golds, scores):
pass
diff --git a/spacy/syntax/nn_parser.pyx b/spacy/syntax/nn_parser.pyx
index 8bac8cd89..043d8d681 100644
--- a/spacy/syntax/nn_parser.pyx
+++ b/spacy/syntax/nn_parser.pyx
@@ -153,7 +153,7 @@ cdef class Parser:
doc (Doc): The document to be processed.
"""
states = self.predict([doc])
- self.set_annotations([doc], states, tensors=None)
+ self.set_annotations([doc], states)
return doc
def pipe(self, docs, int batch_size=256):
@@ -170,7 +170,7 @@ cdef class Parser:
for subbatch in util.minibatch(by_length, size=max(batch_size//4, 2)):
subbatch = list(subbatch)
parse_states = self.predict(subbatch)
- self.set_annotations(subbatch, parse_states, tensors=None)
+ self.set_annotations(subbatch, parse_states)
yield from batch_in_order
def predict(self, docs):
@@ -222,7 +222,7 @@ cdef class Parser:
unfinished.clear()
free_activations(&activations)
- def set_annotations(self, docs, states, tensors=None):
+ def set_annotations(self, docs, states):
cdef StateClass state
cdef Doc doc
for i, (state, doc) in enumerate(zip(states, docs)):
@@ -263,7 +263,7 @@ cdef class Parser:
states[i].push_hist(guess)
free(is_valid)
- def update(self, examples, drop=0., set_annotations=False, sgd=None, losses=None):
+ def update(self, examples, *, drop=0., set_annotations=False, sgd=None, losses=None):
cdef StateClass state
if losses is None:
losses = {}
diff --git a/spacy/tests/regression/test_issue4001-4500.py b/spacy/tests/regression/test_issue4001-4500.py
index 2981c6428..626856e9e 100644
--- a/spacy/tests/regression/test_issue4001-4500.py
+++ b/spacy/tests/regression/test_issue4001-4500.py
@@ -302,7 +302,7 @@ def test_multiple_predictions():
def predict(self, docs):
return ([1, 2, 3], [4, 5, 6])
- def set_annotations(self, docs, scores, tensors=None):
+ def set_annotations(self, docs, scores):
return docs
nlp = Language()
diff --git a/spacy/tests/test_gold.py b/spacy/tests/test_gold.py
index 7d3033560..0b0ba5cad 100644
--- a/spacy/tests/test_gold.py
+++ b/spacy/tests/test_gold.py
@@ -1,3 +1,4 @@
+import numpy
from spacy.errors import AlignmentError
from spacy.gold import biluo_tags_from_offsets, offsets_from_biluo_tags
from spacy.gold import 
spans_from_biluo_tags, iob_to_biluo @@ -154,6 +155,27 @@ def test_gold_biluo_misalign(en_vocab): assert tags == ["O", "O", "O", "-", "-", "-"] +def test_example_constructor(en_vocab): + words = ["I", "like", "stuff"] + tags = ["NOUN", "VERB", "NOUN"] + tag_ids = [en_vocab.strings.add(tag) for tag in tags] + predicted = Doc(en_vocab, words=words) + reference = Doc(en_vocab, words=words) + reference = reference.from_array("TAG", numpy.array(tag_ids, dtype="uint64")) + example = Example(predicted, reference) + tags = example.get_aligned("TAG", as_string=True) + assert tags == ["NOUN", "VERB", "NOUN"] + + +def test_example_from_dict_tags(en_vocab): + words = ["I", "like", "stuff"] + tags = ["NOUN", "VERB", "NOUN"] + predicted = Doc(en_vocab, words=words) + example = Example.from_dict(predicted, {"TAGS": tags}) + tags = example.get_aligned("TAG", as_string=True) + assert tags == ["NOUN", "VERB", "NOUN"] + + def test_example_from_dict_no_ner(en_vocab): words = ["a", "b", "c", "d"] spaces = [True, True, False, True] diff --git a/spacy/tokens/doc.pyx b/spacy/tokens/doc.pyx index ca9230d98..f28bd3374 100644 --- a/spacy/tokens/doc.pyx +++ b/spacy/tokens/doc.pyx @@ -803,7 +803,7 @@ cdef class Doc: attrs = [(IDS[id_.upper()] if hasattr(id_, "upper") else id_) for id_ in attrs] if array.dtype != numpy.uint64: - warnings.warn(Warnings.W101.format(type=array.dtype)) + warnings.warn(Warnings.W028.format(type=array.dtype)) if SENT_START in attrs and HEAD in attrs: raise ValueError(Errors.E032)
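---

Reviewer note: after this patch, every `update()` takes its optional
arguments as keyword-only and always returns the losses dict. The sketch
below is illustrative only, not part of the patch, and assumes an `nlp`
pipeline whose trainable components have already been added and
initialized (e.g. via `nlp.begin_training()`):

    from spacy.gold import Example

    train_data = [("I like stuff", {"tags": ["NOUN", "VERB", "NOUN"]})]
    examples = [Example.from_dict(nlp.make_doc(text), annots)
                for text, annots in train_data]

    losses = {}
    for _ in range(10):
        # drop, sgd and losses are keyword-only now; update() mutates the
        # passed-in dict and also returns it, so either style works.
        losses = nlp.update(examples, drop=0.2, losses=losses)
    print(losses)  # e.g. {"tagger": 1.23}

Because each component now calls `losses.setdefault(self.name, 0.0)` and
returns the dict even for an empty batch, callers no longer need to
pre-seed the keys or special-case a `None` return value.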
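The `Pipe` contract is tightened in the same way: `predict()` returns a
single scores object rather than a `(scores, tensors)` tuple, and
`set_annotations()` no longer accepts a `tensors` argument. A toy
component sketch under that contract; the `DebugLengths` name and its
behaviour are invented for illustration, and it assumes `Pipe` is
importable from `spacy.pipeline` as in this tree:

    from spacy.pipeline import Pipe

    class DebugLengths(Pipe):
        name = "debug_lengths"

        def __init__(self):
            self.model = None  # this toy component has no model

        def predict(self, docs):
            # A single return value: one "score" per doc.
            return [len(doc) for doc in docs]

        def set_annotations(self, docs, scores):
            # Two-argument signature, matching the updated base class.
            for doc, n_tokens in zip(docs, scores):
                doc.user_data["n_tokens"] = n_tokens

The base class `__call__` and `pipe()` now simply do
`self.set_annotations(docs, self.predict(docs))`, which is why the dead
`tensors=...` plumbing could be deleted across all components.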
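Finally, the new W028 warning in spacy/errors.py fires when
`Doc.from_array` receives values that are not `uint64`; the new
`test_example_constructor` avoids it by building a `uint64` array
explicitly. A small sketch of both sides of the check, with an
illustrative setup (a bare `Vocab` instead of the test fixture):

    import warnings
    import numpy
    from spacy.tokens import Doc
    from spacy.vocab import Vocab

    vocab = Vocab()
    words = ["I", "like", "stuff"]
    tag_ids = [vocab.strings.add(tag) for tag in ["NOUN", "VERB", "NOUN"]]

    # uint64 values round-trip cleanly, as in test_example_constructor.
    doc = Doc(vocab, words=words)
    doc = doc.from_array("TAG", numpy.array(tag_ids, dtype="uint64"))

    # Any other dtype now emits W028 instead of silently risking bogus
    # string-store lookups further down the pipeline.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        Doc(vocab, words=words).from_array("TAG", numpy.array(tag_ids, dtype="int64"))
    assert any("W028" in str(w.message) for w in caught)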