mirror of
https://github.com/explosion/spaCy.git
synced 2025-06-02 20:23:12 +03:00
Prevent Tagger model init with 0 labels (#5984)
* Prevent Tagger model init with 0 labels Raise an error before trying to initialize a tagger model with 0 labels. * Add dummy tagger label for test * Remove tagless tagger model initializiation * Fix error number after merge * Add dummy tagger label to test * Fix formatting Co-authored-by: Matthew Honnibal <honnibal+gh@gmail.com>
This commit is contained in:
parent
c38298b8fa
commit
9130094199
|
@ -651,6 +651,7 @@ class Errors:
|
||||||
E1005 = ("Unable to set attribute '{attr}' in tokenizer exception for "
|
E1005 = ("Unable to set attribute '{attr}' in tokenizer exception for "
|
||||||
"'{chunk}'. Tokenizer exceptions are only allowed to specify "
|
"'{chunk}'. Tokenizer exceptions are only allowed to specify "
|
||||||
"`ORTH` and `NORM`.")
|
"`ORTH` and `NORM`.")
|
||||||
|
E1006 = ("Unable to initialize {name} model with 0 labels.")
|
||||||
|
|
||||||
|
|
||||||
@add_codes
|
@add_codes
|
||||||
|
|
|
@ -285,11 +285,11 @@ class Tagger(Pipe):
|
||||||
doc_sample.append(Doc(self.vocab, words=["hello"]))
|
doc_sample.append(Doc(self.vocab, words=["hello"]))
|
||||||
for tag in sorted(tags):
|
for tag in sorted(tags):
|
||||||
self.add_label(tag)
|
self.add_label(tag)
|
||||||
|
if len(self.labels) == 0:
|
||||||
|
err = Errors.E1006.format(name="Tagger")
|
||||||
|
raise ValueError(err)
|
||||||
self.set_output(len(self.labels))
|
self.set_output(len(self.labels))
|
||||||
if self.labels:
|
|
||||||
self.model.initialize(X=doc_sample)
|
self.model.initialize(X=doc_sample)
|
||||||
else:
|
|
||||||
self.model.initialize()
|
|
||||||
if sgd is None:
|
if sgd is None:
|
||||||
sgd = self.create_optimizer()
|
sgd = self.create_optimizer()
|
||||||
return sgd
|
return sgd
|
||||||
|
|
|
@ -69,3 +69,10 @@ def test_overfitting_IO():
|
||||||
assert doc2[1].tag_ is "V"
|
assert doc2[1].tag_ is "V"
|
||||||
assert doc2[2].tag_ is "J"
|
assert doc2[2].tag_ is "J"
|
||||||
assert doc2[3].tag_ is "N"
|
assert doc2[3].tag_ is "N"
|
||||||
|
|
||||||
|
|
||||||
|
def test_tagger_requires_labels():
|
||||||
|
nlp = English()
|
||||||
|
tagger = nlp.add_pipe("tagger")
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
optimizer = nlp.begin_training()
|
||||||
|
|
|
@ -326,7 +326,8 @@ def test_issue4348():
|
||||||
nlp = English()
|
nlp = English()
|
||||||
example = Example.from_dict(nlp.make_doc(""), {"tags": []})
|
example = Example.from_dict(nlp.make_doc(""), {"tags": []})
|
||||||
TRAIN_DATA = [example, example]
|
TRAIN_DATA = [example, example]
|
||||||
nlp.add_pipe("tagger")
|
tagger = nlp.add_pipe("tagger")
|
||||||
|
tagger.add_label("A")
|
||||||
optimizer = nlp.begin_training()
|
optimizer = nlp.begin_training()
|
||||||
for i in range(5):
|
for i in range(5):
|
||||||
losses = {}
|
losses = {}
|
||||||
|
|
|
@ -63,6 +63,7 @@ def tagger():
|
||||||
# need to add model for two reasons:
|
# need to add model for two reasons:
|
||||||
# 1. no model leads to error in serialization,
|
# 1. no model leads to error in serialization,
|
||||||
# 2. the affected line is the one for model serialization
|
# 2. the affected line is the one for model serialization
|
||||||
|
tagger.add_label("A")
|
||||||
tagger.begin_training(lambda: [], pipeline=nlp.pipeline)
|
tagger.begin_training(lambda: [], pipeline=nlp.pipeline)
|
||||||
return tagger
|
return tagger
|
||||||
|
|
||||||
|
|
|
@ -144,6 +144,7 @@ def test_serialize_nlp():
|
||||||
""" Create a custom nlp pipeline from config and ensure it serializes it correctly """
|
""" Create a custom nlp pipeline from config and ensure it serializes it correctly """
|
||||||
nlp_config = Config().from_str(nlp_config_string)
|
nlp_config = Config().from_str(nlp_config_string)
|
||||||
nlp, _ = load_model_from_config(nlp_config, auto_fill=True)
|
nlp, _ = load_model_from_config(nlp_config, auto_fill=True)
|
||||||
|
nlp.get_pipe("tagger").add_label("A")
|
||||||
nlp.begin_training()
|
nlp.begin_training()
|
||||||
assert "tok2vec" in nlp.pipe_names
|
assert "tok2vec" in nlp.pipe_names
|
||||||
assert "tagger" in nlp.pipe_names
|
assert "tagger" in nlp.pipe_names
|
||||||
|
|
Loading…
Reference in New Issue
Block a user