From 3daf57d70cea09d4f1ab14cbc1bc4f9d38adfab2 Mon Sep 17 00:00:00 2001 From: Sofie Van Landeghem Date: Tue, 6 Jul 2021 14:15:41 +0200 Subject: [PATCH] Small spancat fixes (#8614) * two small fixes + additional tests * rename --- spacy/pipeline/spancat.py | 3 +- spacy/tests/pipeline/test_spancat.py | 42 ++++++++++++++++++++++++++++ 2 files changed, 44 insertions(+), 1 deletion(-) diff --git a/spacy/pipeline/spancat.py b/spacy/pipeline/spancat.py index fdf6f9f5e..f5d3d8da9 100644 --- a/spacy/pipeline/spancat.py +++ b/spacy/pipeline/spancat.py @@ -182,6 +182,7 @@ class SpanCategorizer(TrainablePipe): raise ValueError(Errors.E187) if label in self.labels: return 0 + self._allow_extra_label() self.cfg["labels"].append(label) self.vocab.strings.add(label) return 1 @@ -348,7 +349,7 @@ class SpanCategorizer(TrainablePipe): self.add_label(label) for eg in get_examples(): if labels is None: - for span in eg.reference.spans[self.key]: + for span in eg.reference.spans.get(self.key, []): self.add_label(span.label_) if len(subbatch) < 10: subbatch.append(eg) diff --git a/spacy/tests/pipeline/test_spancat.py b/spacy/tests/pipeline/test_spancat.py index f70df7478..5e19f5e5e 100644 --- a/spacy/tests/pipeline/test_spancat.py +++ b/spacy/tests/pipeline/test_spancat.py @@ -1,3 +1,4 @@ +import pytest from numpy.testing import assert_equal from spacy.language import Language from spacy.training import Example @@ -27,6 +28,47 @@ def make_get_examples(nlp): return get_examples +def test_no_label(): + nlp = Language() + nlp.add_pipe("spancat", config={"spans_key": SPAN_KEY}) + with pytest.raises(ValueError): + nlp.initialize() + + +def test_no_resize(): + nlp = Language() + spancat = nlp.add_pipe("spancat", config={"spans_key": SPAN_KEY}) + spancat.add_label("Thing") + spancat.add_label("Phrase") + assert spancat.labels == ("Thing", "Phrase") + nlp.initialize() + assert spancat.model.get_dim("nO") == 2 + # this throws an error because the spancat can't be resized after initialization + with pytest.raises(ValueError): + spancat.add_label("Stuff") + + +def test_implicit_labels(): + nlp = Language() + spancat = nlp.add_pipe("spancat", config={"spans_key": SPAN_KEY}) + assert len(spancat.labels) == 0 + train_examples = [] + for t in TRAIN_DATA: + train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1])) + nlp.initialize(get_examples=lambda: train_examples) + assert spancat.labels == ("PERSON", "LOC") + + +def test_explicit_labels(): + nlp = Language() + spancat = nlp.add_pipe("spancat", config={"spans_key": SPAN_KEY}) + assert len(spancat.labels) == 0 + spancat.add_label("PERSON") + spancat.add_label("LOC") + nlp.initialize() + assert spancat.labels == ("PERSON", "LOC") + + def test_simple_train(): fix_random_seed(0) nlp = Language()