mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 10:16:27 +03:00
Tagger robustness (#6580)
* require labels in taggers * ensure tagger works with incomplete data
This commit is contained in:
parent
e10295c9fd
commit
0a923a7915
|
@ -256,7 +256,7 @@ class Tagger(TrainablePipe):
|
||||||
DOCS: https://nightly.spacy.io/api/tagger#get_loss
|
DOCS: https://nightly.spacy.io/api/tagger#get_loss
|
||||||
"""
|
"""
|
||||||
validate_examples(examples, "Tagger.get_loss")
|
validate_examples(examples, "Tagger.get_loss")
|
||||||
loss_func = SequenceCategoricalCrossentropy(names=self.labels, normalize=False)
|
loss_func = SequenceCategoricalCrossentropy(names=self.labels, normalize=False, missing_value="")
|
||||||
truths = [eg.get_aligned("TAG", as_string=True) for eg in examples]
|
truths = [eg.get_aligned("TAG", as_string=True) for eg in examples]
|
||||||
d_scores, loss = loss_func(scores, truths)
|
d_scores, loss = loss_func(scores, truths)
|
||||||
if self.model.ops.xp.isnan(loss):
|
if self.model.ops.xp.isnan(loss):
|
||||||
|
@ -295,6 +295,7 @@ class Tagger(TrainablePipe):
|
||||||
gold_tags = example.get_aligned("TAG", as_string=True)
|
gold_tags = example.get_aligned("TAG", as_string=True)
|
||||||
gold_array = [[1.0 if tag == gold_tag else 0.0 for tag in self.labels] for gold_tag in gold_tags]
|
gold_array = [[1.0 if tag == gold_tag else 0.0 for tag in self.labels] for gold_tag in gold_tags]
|
||||||
label_sample.append(self.model.ops.asarray(gold_array, dtype="float32"))
|
label_sample.append(self.model.ops.asarray(gold_array, dtype="float32"))
|
||||||
|
self._require_labels()
|
||||||
assert len(doc_sample) > 0, Errors.E923.format(name=self.name)
|
assert len(doc_sample) > 0, Errors.E923.format(name=self.name)
|
||||||
assert len(label_sample) > 0, Errors.E923.format(name=self.name)
|
assert len(label_sample) > 0, Errors.E923.format(name=self.name)
|
||||||
self.model.initialize(X=doc_sample, Y=label_sample)
|
self.model.initialize(X=doc_sample, Y=label_sample)
|
||||||
|
|
|
@ -36,6 +36,10 @@ TRAIN_DATA = [
|
||||||
("Eat blue ham", {"tags": ["V", "J", "N"]}),
|
("Eat blue ham", {"tags": ["V", "J", "N"]}),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
PARTIAL_DATA = [
|
||||||
|
("I like green eggs", {"tags": ["", "V", "J", ""]}),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def test_no_label():
|
def test_no_label():
|
||||||
nlp = Language()
|
nlp = Language()
|
||||||
|
@ -87,6 +91,41 @@ def test_initialize_examples():
|
||||||
nlp.initialize(get_examples=train_examples)
|
nlp.initialize(get_examples=train_examples)
|
||||||
|
|
||||||
|
|
||||||
|
def test_no_data():
|
||||||
|
# Test that the tagger provides a nice error when there's no tagging data / labels
|
||||||
|
TEXTCAT_DATA = [
|
||||||
|
("I'm so happy.", {"cats": {"POSITIVE": 1.0, "NEGATIVE": 0.0}}),
|
||||||
|
("I'm so angry", {"cats": {"POSITIVE": 0.0, "NEGATIVE": 1.0}}),
|
||||||
|
]
|
||||||
|
nlp = English()
|
||||||
|
nlp.add_pipe("tagger")
|
||||||
|
nlp.add_pipe("textcat")
|
||||||
|
train_examples = []
|
||||||
|
for t in TEXTCAT_DATA:
|
||||||
|
train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1]))
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
nlp.initialize(get_examples=lambda: train_examples)
|
||||||
|
|
||||||
|
|
||||||
|
def test_incomplete_data():
|
||||||
|
# Test that the tagger works with incomplete information
|
||||||
|
nlp = English()
|
||||||
|
nlp.add_pipe("tagger")
|
||||||
|
train_examples = []
|
||||||
|
for t in PARTIAL_DATA:
|
||||||
|
train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1]))
|
||||||
|
optimizer = nlp.initialize(get_examples=lambda: train_examples)
|
||||||
|
for i in range(50):
|
||||||
|
losses = {}
|
||||||
|
nlp.update(train_examples, sgd=optimizer, losses=losses)
|
||||||
|
assert losses["tagger"] < 0.00001
|
||||||
|
|
||||||
|
# test the trained model
|
||||||
|
test_text = "I like blue eggs"
|
||||||
|
doc = nlp(test_text)
|
||||||
|
assert doc[1].tag_ is "V"
|
||||||
|
assert doc[2].tag_ is "J"
|
||||||
|
|
||||||
def test_overfitting_IO():
|
def test_overfitting_IO():
|
||||||
# Simple test to try and quickly overfit the tagger - ensuring the ML models work correctly
|
# Simple test to try and quickly overfit the tagger - ensuring the ML models work correctly
|
||||||
nlp = English()
|
nlp = English()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user