Tagger robustness (#6580)

* require labels in taggers

* ensure tagger works with incomplete data
This commit is contained in:
Sofie Van Landeghem 2020-12-18 11:51:47 +01:00 committed by GitHub
parent e10295c9fd
commit 0a923a7915
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 41 additions and 1 deletions

View File

@ -256,7 +256,7 @@ class Tagger(TrainablePipe):
DOCS: https://nightly.spacy.io/api/tagger#get_loss
"""
validate_examples(examples, "Tagger.get_loss")
loss_func = SequenceCategoricalCrossentropy(names=self.labels, normalize=False)
loss_func = SequenceCategoricalCrossentropy(names=self.labels, normalize=False, missing_value="")
truths = [eg.get_aligned("TAG", as_string=True) for eg in examples]
d_scores, loss = loss_func(scores, truths)
if self.model.ops.xp.isnan(loss):
@ -295,6 +295,7 @@ class Tagger(TrainablePipe):
gold_tags = example.get_aligned("TAG", as_string=True)
gold_array = [[1.0 if tag == gold_tag else 0.0 for tag in self.labels] for gold_tag in gold_tags]
label_sample.append(self.model.ops.asarray(gold_array, dtype="float32"))
self._require_labels()
assert len(doc_sample) > 0, Errors.E923.format(name=self.name)
assert len(label_sample) > 0, Errors.E923.format(name=self.name)
self.model.initialize(X=doc_sample, Y=label_sample)

View File

@ -36,6 +36,10 @@ TRAIN_DATA = [
("Eat blue ham", {"tags": ["V", "J", "N"]}),
]
PARTIAL_DATA = [
("I like green eggs", {"tags": ["", "V", "J", ""]}),
]
def test_no_label():
nlp = Language()
@ -87,6 +91,41 @@ def test_initialize_examples():
nlp.initialize(get_examples=train_examples)
def test_no_data():
# Test that the tagger provides a nice error when there's no tagging data / labels
TEXTCAT_DATA = [
("I'm so happy.", {"cats": {"POSITIVE": 1.0, "NEGATIVE": 0.0}}),
("I'm so angry", {"cats": {"POSITIVE": 0.0, "NEGATIVE": 1.0}}),
]
nlp = English()
nlp.add_pipe("tagger")
nlp.add_pipe("textcat")
train_examples = []
for t in TEXTCAT_DATA:
train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1]))
with pytest.raises(ValueError):
nlp.initialize(get_examples=lambda: train_examples)
def test_incomplete_data():
# Test that the tagger works with incomplete information
nlp = English()
nlp.add_pipe("tagger")
train_examples = []
for t in PARTIAL_DATA:
train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1]))
optimizer = nlp.initialize(get_examples=lambda: train_examples)
for i in range(50):
losses = {}
nlp.update(train_examples, sgd=optimizer, losses=losses)
assert losses["tagger"] < 0.00001
# test the trained model
test_text = "I like blue eggs"
doc = nlp(test_text)
assert doc[1].tag_ is "V"
assert doc[2].tag_ is "J"
def test_overfitting_IO():
# Simple test to try and quickly overfit the tagger - ensuring the ML models work correctly
nlp = English()