Small spancat fixes (#8614)

* two small fixes + additional tests

* rename
This commit is contained in:
Sofie Van Landeghem 2021-07-06 14:15:41 +02:00 committed by GitHub
parent 327f83573a
commit 3daf57d70c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 44 additions and 1 deletions

View File

@ -182,6 +182,7 @@ class SpanCategorizer(TrainablePipe):
raise ValueError(Errors.E187)
if label in self.labels:
return 0
self._allow_extra_label()
self.cfg["labels"].append(label)
self.vocab.strings.add(label)
return 1
@ -348,7 +349,7 @@ class SpanCategorizer(TrainablePipe):
self.add_label(label)
for eg in get_examples():
if labels is None:
for span in eg.reference.spans[self.key]:
for span in eg.reference.spans.get(self.key, []):
self.add_label(span.label_)
if len(subbatch) < 10:
subbatch.append(eg)

View File

@ -1,3 +1,4 @@
import pytest
from numpy.testing import assert_equal
from spacy.language import Language
from spacy.training import Example
@ -27,6 +28,47 @@ def make_get_examples(nlp):
return get_examples
def test_no_label():
nlp = Language()
nlp.add_pipe("spancat", config={"spans_key": SPAN_KEY})
with pytest.raises(ValueError):
nlp.initialize()
def test_no_resize():
nlp = Language()
spancat = nlp.add_pipe("spancat", config={"spans_key": SPAN_KEY})
spancat.add_label("Thing")
spancat.add_label("Phrase")
assert spancat.labels == ("Thing", "Phrase")
nlp.initialize()
assert spancat.model.get_dim("nO") == 2
# this throws an error because the spancat can't be resized after initialization
with pytest.raises(ValueError):
spancat.add_label("Stuff")
def test_implicit_labels():
nlp = Language()
spancat = nlp.add_pipe("spancat", config={"spans_key": SPAN_KEY})
assert len(spancat.labels) == 0
train_examples = []
for t in TRAIN_DATA:
train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1]))
nlp.initialize(get_examples=lambda: train_examples)
assert spancat.labels == ("PERSON", "LOC")
def test_explicit_labels():
nlp = Language()
spancat = nlp.add_pipe("spancat", config={"spans_key": SPAN_KEY})
assert len(spancat.labels) == 0
spancat.add_label("PERSON")
spancat.add_label("LOC")
nlp.initialize()
assert spancat.labels == ("PERSON", "LOC")
def test_simple_train():
fix_random_seed(0)
nlp = Language()