Prevent Tagger model init with 0 labels (#5984)

* Prevent Tagger model init with 0 labels

Raise an error before trying to initialize a tagger model with 0 labels.

* Add dummy tagger label for test

* Remove tagless tagger model initializiation

* Fix error number after merge

* Add dummy tagger label to test

* Fix formatting

Co-authored-by: Matthew Honnibal <honnibal+gh@gmail.com>
This commit is contained in:
Adriane Boyd 2020-08-31 21:24:33 +02:00 committed by GitHub
parent c38298b8fa
commit 9130094199
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 16 additions and 5 deletions

View File

@ -651,6 +651,7 @@ class Errors:
E1005 = ("Unable to set attribute '{attr}' in tokenizer exception for "
"'{chunk}'. Tokenizer exceptions are only allowed to specify "
"`ORTH` and `NORM`.")
E1006 = ("Unable to initialize {name} model with 0 labels.")
@add_codes

View File

@ -285,11 +285,11 @@ class Tagger(Pipe):
doc_sample.append(Doc(self.vocab, words=["hello"]))
for tag in sorted(tags):
self.add_label(tag)
if len(self.labels) == 0:
err = Errors.E1006.format(name="Tagger")
raise ValueError(err)
self.set_output(len(self.labels))
if self.labels:
self.model.initialize(X=doc_sample)
else:
self.model.initialize()
self.model.initialize(X=doc_sample)
if sgd is None:
sgd = self.create_optimizer()
return sgd

View File

@ -69,3 +69,10 @@ def test_overfitting_IO():
assert doc2[1].tag_ is "V"
assert doc2[2].tag_ is "J"
assert doc2[3].tag_ is "N"
def test_tagger_requires_labels():
nlp = English()
tagger = nlp.add_pipe("tagger")
with pytest.raises(ValueError):
optimizer = nlp.begin_training()

View File

@ -326,7 +326,8 @@ def test_issue4348():
nlp = English()
example = Example.from_dict(nlp.make_doc(""), {"tags": []})
TRAIN_DATA = [example, example]
nlp.add_pipe("tagger")
tagger = nlp.add_pipe("tagger")
tagger.add_label("A")
optimizer = nlp.begin_training()
for i in range(5):
losses = {}

View File

@ -63,6 +63,7 @@ def tagger():
# need to add model for two reasons:
# 1. no model leads to error in serialization,
# 2. the affected line is the one for model serialization
tagger.add_label("A")
tagger.begin_training(lambda: [], pipeline=nlp.pipeline)
return tagger

View File

@ -144,6 +144,7 @@ def test_serialize_nlp():
""" Create a custom nlp pipeline from config and ensure it serializes it correctly """
nlp_config = Config().from_str(nlp_config_string)
nlp, _ = load_model_from_config(nlp_config, auto_fill=True)
nlp.get_pipe("tagger").add_label("A")
nlp.begin_training()
assert "tok2vec" in nlp.pipe_names
assert "tagger" in nlp.pipe_names