Tagger robustness (#6580)

* require labels in taggers
* ensure tagger works with incomplete data

commit 0a923a7915 (parent e10295c9fd)
spacy/pipeline/tagger.pyx

@@ -256,7 +256,7 @@ class Tagger(TrainablePipe):

         DOCS: https://nightly.spacy.io/api/tagger#get_loss
         """
         validate_examples(examples, "Tagger.get_loss")
-        loss_func = SequenceCategoricalCrossentropy(names=self.labels, normalize=False)
+        loss_func = SequenceCategoricalCrossentropy(names=self.labels, normalize=False, missing_value="")
         truths = [eg.get_aligned("TAG", as_string=True) for eg in examples]
         d_scores, loss = loss_func(scores, truths)
         if self.model.ops.xp.isnan(loss):
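The only functional change in get_loss is the added missing_value="": with partially annotated data, eg.get_aligned("TAG", as_string=True) comes back empty for tokens that have no gold tag, and the loss should skip those positions rather than treat "" as a wrong prediction. The following is a minimal sketch of that behaviour using Thinc directly; the three-tag label set and the score matrix are invented for illustration, whereas the real component passes self.labels and the model's scores.

# Sketch only: the expected effect of missing_value="" in the sequence loss.
import numpy
from thinc.api import SequenceCategoricalCrossentropy

labels = ("J", "N", "V")
loss_func = SequenceCategoricalCrossentropy(names=labels, normalize=False, missing_value="")

# One document with three tokens; the middle token has no gold tag ("").
scores = [
    numpy.asarray(
        [
            [0.2, 0.2, 0.6],  # token 0, gold tag "V"
            [0.1, 0.8, 0.1],  # token 1, unannotated
            [0.7, 0.2, 0.1],  # token 2, gold tag "J"
        ],
        dtype="f",
    )
]
truths = [["V", "", "J"]]

d_scores, loss = loss_func(scores, truths)
# The gradient row for the unannotated token should come out as all zeros,
# so the update leaves that position alone.
print(d_scores[0][1])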
@@ -295,6 +295,7 @@ class Tagger(TrainablePipe):
             gold_tags = example.get_aligned("TAG", as_string=True)
             gold_array = [[1.0 if tag == gold_tag else 0.0 for tag in self.labels] for gold_tag in gold_tags]
             label_sample.append(self.model.ops.asarray(gold_array, dtype="float32"))
+        self._require_labels()
         assert len(doc_sample) > 0, Errors.E923.format(name=self.name)
         assert len(label_sample) > 0, Errors.E923.format(name=self.name)
         self.model.initialize(X=doc_sample, Y=label_sample)
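In initialize, the new self._require_labels() call makes the tagger fail fast when the training data yields no labels at all, instead of trying to build a model with no output classes. Below is a minimal sketch of the user-facing effect, assuming spaCy v3.x; the training example is invented and deliberately carries no "tags" annotation, so a ValueError is expected.

# Sketch of the user-facing effect of requiring labels at initialization.
# The training example is invented and has no tag annotations on purpose.
from spacy.lang.en import English
from spacy.training import Example

nlp = English()
nlp.add_pipe("tagger")
examples = [Example.from_dict(nlp.make_doc("No tag annotations here"), {})]
try:
    nlp.initialize(get_examples=lambda: examples)
except ValueError as err:
    print("Tagger refused to initialize without labels:", err)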
spacy/tests/pipeline/test_tagger.py

@@ -36,6 +36,10 @@ TRAIN_DATA = [
     ("Eat blue ham", {"tags": ["V", "J", "N"]}),
 ]

+PARTIAL_DATA = [
+    ("I like green eggs", {"tags": ["", "V", "J", ""]}),
+]
+

 def test_no_label():
     nlp = Language()
@@ -87,6 +91,41 @@ def test_initialize_examples():
         nlp.initialize(get_examples=train_examples)


+def test_no_data():
+    # Test that the tagger provides a nice error when there's no tagging data / labels
+    TEXTCAT_DATA = [
+        ("I'm so happy.", {"cats": {"POSITIVE": 1.0, "NEGATIVE": 0.0}}),
+        ("I'm so angry", {"cats": {"POSITIVE": 0.0, "NEGATIVE": 1.0}}),
+    ]
+    nlp = English()
+    nlp.add_pipe("tagger")
+    nlp.add_pipe("textcat")
+    train_examples = []
+    for t in TEXTCAT_DATA:
+        train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1]))
+    with pytest.raises(ValueError):
+        nlp.initialize(get_examples=lambda: train_examples)
+
+
+def test_incomplete_data():
+    # Test that the tagger works with incomplete information
+    nlp = English()
+    nlp.add_pipe("tagger")
+    train_examples = []
+    for t in PARTIAL_DATA:
+        train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1]))
+    optimizer = nlp.initialize(get_examples=lambda: train_examples)
+    for i in range(50):
+        losses = {}
+        nlp.update(train_examples, sgd=optimizer, losses=losses)
+    assert losses["tagger"] < 0.00001
+
+    # test the trained model
+    test_text = "I like blue eggs"
+    doc = nlp(test_text)
+    assert doc[1].tag_ == "V"
+    assert doc[2].tag_ == "J"
+
+
 def test_overfitting_IO():
     # Simple test to try and quickly overfit the tagger - ensuring the ML models work correctly
     nlp = English()