mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-24 17:06:29 +03:00
Prevent Tagger model init with 0 labels (#5984)
* Prevent Tagger model init with 0 labels Raise an error before trying to initialize a tagger model with 0 labels. * Add dummy tagger label for test * Remove tagless tagger model initializiation * Fix error number after merge * Add dummy tagger label to test * Fix formatting Co-authored-by: Matthew Honnibal <honnibal+gh@gmail.com>
This commit is contained in:
parent
c38298b8fa
commit
9130094199
|
@ -651,6 +651,7 @@ class Errors:
|
|||
E1005 = ("Unable to set attribute '{attr}' in tokenizer exception for "
|
||||
"'{chunk}'. Tokenizer exceptions are only allowed to specify "
|
||||
"`ORTH` and `NORM`.")
|
||||
E1006 = ("Unable to initialize {name} model with 0 labels.")
|
||||
|
||||
|
||||
@add_codes
|
||||
|
|
|
@ -285,11 +285,11 @@ class Tagger(Pipe):
|
|||
doc_sample.append(Doc(self.vocab, words=["hello"]))
|
||||
for tag in sorted(tags):
|
||||
self.add_label(tag)
|
||||
if len(self.labels) == 0:
|
||||
err = Errors.E1006.format(name="Tagger")
|
||||
raise ValueError(err)
|
||||
self.set_output(len(self.labels))
|
||||
if self.labels:
|
||||
self.model.initialize(X=doc_sample)
|
||||
else:
|
||||
self.model.initialize()
|
||||
self.model.initialize(X=doc_sample)
|
||||
if sgd is None:
|
||||
sgd = self.create_optimizer()
|
||||
return sgd
|
||||
|
|
|
@ -69,3 +69,10 @@ def test_overfitting_IO():
|
|||
assert doc2[1].tag_ is "V"
|
||||
assert doc2[2].tag_ is "J"
|
||||
assert doc2[3].tag_ is "N"
|
||||
|
||||
|
||||
def test_tagger_requires_labels():
|
||||
nlp = English()
|
||||
tagger = nlp.add_pipe("tagger")
|
||||
with pytest.raises(ValueError):
|
||||
optimizer = nlp.begin_training()
|
||||
|
|
|
@ -326,7 +326,8 @@ def test_issue4348():
|
|||
nlp = English()
|
||||
example = Example.from_dict(nlp.make_doc(""), {"tags": []})
|
||||
TRAIN_DATA = [example, example]
|
||||
nlp.add_pipe("tagger")
|
||||
tagger = nlp.add_pipe("tagger")
|
||||
tagger.add_label("A")
|
||||
optimizer = nlp.begin_training()
|
||||
for i in range(5):
|
||||
losses = {}
|
||||
|
|
|
@ -63,6 +63,7 @@ def tagger():
|
|||
# need to add model for two reasons:
|
||||
# 1. no model leads to error in serialization,
|
||||
# 2. the affected line is the one for model serialization
|
||||
tagger.add_label("A")
|
||||
tagger.begin_training(lambda: [], pipeline=nlp.pipeline)
|
||||
return tagger
|
||||
|
||||
|
|
|
@ -144,6 +144,7 @@ def test_serialize_nlp():
|
|||
""" Create a custom nlp pipeline from config and ensure it serializes it correctly """
|
||||
nlp_config = Config().from_str(nlp_config_string)
|
||||
nlp, _ = load_model_from_config(nlp_config, auto_fill=True)
|
||||
nlp.get_pipe("tagger").add_label("A")
|
||||
nlp.begin_training()
|
||||
assert "tok2vec" in nlp.pipe_names
|
||||
assert "tagger" in nlp.pipe_names
|
||||
|
|
Loading…
Reference in New Issue
Block a user