mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 01:16:28 +03:00
Small spancat fixes (#8614)
* two small fixes + additional tests * rename
This commit is contained in:
parent
327f83573a
commit
3daf57d70c
|
@ -182,6 +182,7 @@ class SpanCategorizer(TrainablePipe):
|
|||
raise ValueError(Errors.E187)
|
||||
if label in self.labels:
|
||||
return 0
|
||||
self._allow_extra_label()
|
||||
self.cfg["labels"].append(label)
|
||||
self.vocab.strings.add(label)
|
||||
return 1
|
||||
|
@ -348,7 +349,7 @@ class SpanCategorizer(TrainablePipe):
|
|||
self.add_label(label)
|
||||
for eg in get_examples():
|
||||
if labels is None:
|
||||
for span in eg.reference.spans[self.key]:
|
||||
for span in eg.reference.spans.get(self.key, []):
|
||||
self.add_label(span.label_)
|
||||
if len(subbatch) < 10:
|
||||
subbatch.append(eg)
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import pytest
|
||||
from numpy.testing import assert_equal
|
||||
from spacy.language import Language
|
||||
from spacy.training import Example
|
||||
|
@ -27,6 +28,47 @@ def make_get_examples(nlp):
|
|||
return get_examples
|
||||
|
||||
|
||||
def test_no_label():
|
||||
nlp = Language()
|
||||
nlp.add_pipe("spancat", config={"spans_key": SPAN_KEY})
|
||||
with pytest.raises(ValueError):
|
||||
nlp.initialize()
|
||||
|
||||
|
||||
def test_no_resize():
|
||||
nlp = Language()
|
||||
spancat = nlp.add_pipe("spancat", config={"spans_key": SPAN_KEY})
|
||||
spancat.add_label("Thing")
|
||||
spancat.add_label("Phrase")
|
||||
assert spancat.labels == ("Thing", "Phrase")
|
||||
nlp.initialize()
|
||||
assert spancat.model.get_dim("nO") == 2
|
||||
# this throws an error because the spancat can't be resized after initialization
|
||||
with pytest.raises(ValueError):
|
||||
spancat.add_label("Stuff")
|
||||
|
||||
|
||||
def test_implicit_labels():
|
||||
nlp = Language()
|
||||
spancat = nlp.add_pipe("spancat", config={"spans_key": SPAN_KEY})
|
||||
assert len(spancat.labels) == 0
|
||||
train_examples = []
|
||||
for t in TRAIN_DATA:
|
||||
train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1]))
|
||||
nlp.initialize(get_examples=lambda: train_examples)
|
||||
assert spancat.labels == ("PERSON", "LOC")
|
||||
|
||||
|
||||
def test_explicit_labels():
|
||||
nlp = Language()
|
||||
spancat = nlp.add_pipe("spancat", config={"spans_key": SPAN_KEY})
|
||||
assert len(spancat.labels) == 0
|
||||
spancat.add_label("PERSON")
|
||||
spancat.add_label("LOC")
|
||||
nlp.initialize()
|
||||
assert spancat.labels == ("PERSON", "LOC")
|
||||
|
||||
|
||||
def test_simple_train():
|
||||
fix_random_seed(0)
|
||||
nlp = Language()
|
||||
|
|
Loading…
Reference in New Issue
Block a user