mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 18:26:30 +03:00
Small spancat fixes (#8614)
* two small fixes + additional tests * rename
This commit is contained in:
parent
327f83573a
commit
3daf57d70c
|
@ -182,6 +182,7 @@ class SpanCategorizer(TrainablePipe):
|
||||||
raise ValueError(Errors.E187)
|
raise ValueError(Errors.E187)
|
||||||
if label in self.labels:
|
if label in self.labels:
|
||||||
return 0
|
return 0
|
||||||
|
self._allow_extra_label()
|
||||||
self.cfg["labels"].append(label)
|
self.cfg["labels"].append(label)
|
||||||
self.vocab.strings.add(label)
|
self.vocab.strings.add(label)
|
||||||
return 1
|
return 1
|
||||||
|
@ -348,7 +349,7 @@ class SpanCategorizer(TrainablePipe):
|
||||||
self.add_label(label)
|
self.add_label(label)
|
||||||
for eg in get_examples():
|
for eg in get_examples():
|
||||||
if labels is None:
|
if labels is None:
|
||||||
for span in eg.reference.spans[self.key]:
|
for span in eg.reference.spans.get(self.key, []):
|
||||||
self.add_label(span.label_)
|
self.add_label(span.label_)
|
||||||
if len(subbatch) < 10:
|
if len(subbatch) < 10:
|
||||||
subbatch.append(eg)
|
subbatch.append(eg)
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
import pytest
|
||||||
from numpy.testing import assert_equal
|
from numpy.testing import assert_equal
|
||||||
from spacy.language import Language
|
from spacy.language import Language
|
||||||
from spacy.training import Example
|
from spacy.training import Example
|
||||||
|
@ -27,6 +28,47 @@ def make_get_examples(nlp):
|
||||||
return get_examples
|
return get_examples
|
||||||
|
|
||||||
|
|
||||||
|
def test_no_label():
|
||||||
|
nlp = Language()
|
||||||
|
nlp.add_pipe("spancat", config={"spans_key": SPAN_KEY})
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
nlp.initialize()
|
||||||
|
|
||||||
|
|
||||||
|
def test_no_resize():
|
||||||
|
nlp = Language()
|
||||||
|
spancat = nlp.add_pipe("spancat", config={"spans_key": SPAN_KEY})
|
||||||
|
spancat.add_label("Thing")
|
||||||
|
spancat.add_label("Phrase")
|
||||||
|
assert spancat.labels == ("Thing", "Phrase")
|
||||||
|
nlp.initialize()
|
||||||
|
assert spancat.model.get_dim("nO") == 2
|
||||||
|
# this throws an error because the spancat can't be resized after initialization
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
spancat.add_label("Stuff")
|
||||||
|
|
||||||
|
|
||||||
|
def test_implicit_labels():
|
||||||
|
nlp = Language()
|
||||||
|
spancat = nlp.add_pipe("spancat", config={"spans_key": SPAN_KEY})
|
||||||
|
assert len(spancat.labels) == 0
|
||||||
|
train_examples = []
|
||||||
|
for t in TRAIN_DATA:
|
||||||
|
train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1]))
|
||||||
|
nlp.initialize(get_examples=lambda: train_examples)
|
||||||
|
assert spancat.labels == ("PERSON", "LOC")
|
||||||
|
|
||||||
|
|
||||||
|
def test_explicit_labels():
|
||||||
|
nlp = Language()
|
||||||
|
spancat = nlp.add_pipe("spancat", config={"spans_key": SPAN_KEY})
|
||||||
|
assert len(spancat.labels) == 0
|
||||||
|
spancat.add_label("PERSON")
|
||||||
|
spancat.add_label("LOC")
|
||||||
|
nlp.initialize()
|
||||||
|
assert spancat.labels == ("PERSON", "LOC")
|
||||||
|
|
||||||
|
|
||||||
def test_simple_train():
|
def test_simple_train():
|
||||||
fix_random_seed(0)
|
fix_random_seed(0)
|
||||||
nlp = Language()
|
nlp = Language()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user