From fc69da0acb655784bdc0d0515e4aeef2a0ca031f Mon Sep 17 00:00:00 2001 From: Ines Montani Date: Sat, 27 Jul 2019 17:30:18 +0200 Subject: [PATCH] =?UTF-8?q?=F0=9F=92=AB=20Support=20simple=20training=20fo?= =?UTF-8?q?rmat=20in=20nlp.evaluate=20and=20add=20tests=20(#4033)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Support simple training format in nlp.evaluate and add tests * Update docs [ci skip] --- spacy/language.py | 4 ++- spacy/tests/test_language.py | 57 ++++++++++++++++++++++++++++++++++++ website/docs/api/language.md | 14 ++++----- 3 files changed, 67 insertions(+), 8 deletions(-) create mode 100644 spacy/tests/test_language.py diff --git a/spacy/language.py b/spacy/language.py index 39d95c689..bfdd00b79 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -618,7 +618,7 @@ class Language(object): if component_cfg is None: component_cfg = {} docs, golds = zip(*docs_golds) - docs = list(docs) + docs = [self.make_doc(doc) if isinstance(doc, basestring_) else doc for doc in docs] golds = list(golds) for name, pipe in self.pipeline: kwargs = component_cfg.get(name, {}) @@ -628,6 +628,8 @@ class Language(object): else: docs = pipe.pipe(docs, **kwargs) for doc, gold in zip(docs, golds): + if not isinstance(gold, GoldParse): + gold = GoldParse(doc, **gold) if verbose: print(doc) kwargs = component_cfg.get("scorer", {}) diff --git a/spacy/tests/test_language.py b/spacy/tests/test_language.py new file mode 100644 index 000000000..00175fe9a --- /dev/null +++ b/spacy/tests/test_language.py @@ -0,0 +1,57 @@ +# coding: utf-8 +from __future__ import unicode_literals + +import pytest +from spacy.vocab import Vocab +from spacy.language import Language +from spacy.tokens import Doc +from spacy.gold import GoldParse + + +@pytest.fixture +def nlp(): + nlp = Language(Vocab()) + textcat = nlp.create_pipe("textcat") + for label in ("POSITIVE", "NEGATIVE"): + textcat.add_label(label) + nlp.add_pipe(textcat) + nlp.begin_training() + return nlp + + +def test_language_update(nlp): + text = "hello world" + annots = {"cats": {"POSITIVE": 1.0, "NEGATIVE": 0.0}} + doc = Doc(nlp.vocab, words=text.split(" ")) + gold = GoldParse(doc, **annots) + # Update with doc and gold objects + nlp.update([doc], [gold]) + # Update with text and dict + nlp.update([text], [annots]) + # Update with doc object and dict + nlp.update([doc], [annots]) + # Update with text and gold object + nlp.update([text], [gold]) + # Update badly + with pytest.raises(IndexError): + nlp.update([doc], []) + with pytest.raises(IndexError): + nlp.update([], [gold]) + + +def test_language_evaluate(nlp): + text = "hello world" + annots = {"cats": {"POSITIVE": 1.0, "NEGATIVE": 0.0}} + doc = Doc(nlp.vocab, words=text.split(" ")) + gold = GoldParse(doc, **annots) + # Evaluate with doc and gold objects + nlp.evaluate([(doc, gold)]) + # Evaluate with text and dict + nlp.evaluate([(text, annots)]) + # Evaluate with doc object and dict + nlp.evaluate([(doc, annots)]) + # Evaluate with text and gold object + nlp.evaluate([(text, gold)]) + # Evaluate badly + with pytest.raises(Exception): + nlp.evaluate([text, gold]) diff --git a/website/docs/api/language.md b/website/docs/api/language.md index 3245a165b..3fcdeb195 100644 --- a/website/docs/api/language.md +++ b/website/docs/api/language.md @@ -133,13 +133,13 @@ Evaluate a model's pipeline components. > print(scorer.scores) > ``` -| Name | Type | Description | -| -------------------------------------------- | -------- | ------------------------------------------------------------------------------------- | -| `docs_golds` | iterable | Tuples of `Doc` and `GoldParse` objects. | -| `verbose` | bool | Print debugging information. | -| `batch_size` | int | The batch size to use. | -| `scorer` | `Scorer` | Optional [`Scorer`](/api/scorer) to use. If not passed in, a new one will be created. | -| `component_cfg` 2.1 | dict | Config parameters for specific pipeline components, keyed by component name. | +| Name | Type | Description | +| -------------------------------------------- | -------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `docs_golds` | iterable | Tuples of `Doc` and `GoldParse` objects or `(text, annotations)` of raw text and a dict (see [simple training style](/usage/training#training-simple-style)). | +| `verbose` | bool | Print debugging information. | +| `batch_size` | int | The batch size to use. | +| `scorer` | `Scorer` | Optional [`Scorer`](/api/scorer) to use. If not passed in, a new one will be created. | +| `component_cfg` 2.1 | dict | Config parameters for specific pipeline components, keyed by component name. | ## Language.begin_training {#begin_training tag="method"}