mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-11-04 01:48:04 +03:00 
			
		
		
		
	Add xfail test for #3433. Improve test for add label.
This commit is contained in:
		
							parent
							
								
									6b6e9b638e
								
							
						
					
					
						commit
						444a3abfe5
					
				| 
						 | 
					@ -8,7 +8,8 @@ from spacy.attrs import NORM
 | 
				
			||||||
from spacy.gold import GoldParse
 | 
					from spacy.gold import GoldParse
 | 
				
			||||||
from spacy.vocab import Vocab
 | 
					from spacy.vocab import Vocab
 | 
				
			||||||
from spacy.tokens import Doc
 | 
					from spacy.tokens import Doc
 | 
				
			||||||
from spacy.pipeline import DependencyParser
 | 
					from spacy.pipeline import DependencyParser, EntityRecognizer
 | 
				
			||||||
 | 
					from spacy.util import fix_random_seed
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@pytest.fixture
 | 
					@pytest.fixture
 | 
				
			||||||
| 
						 | 
					@ -19,18 +20,6 @@ def vocab():
 | 
				
			||||||
@pytest.fixture
 | 
					@pytest.fixture
 | 
				
			||||||
def parser(vocab):
 | 
					def parser(vocab):
 | 
				
			||||||
    parser = DependencyParser(vocab)
 | 
					    parser = DependencyParser(vocab)
 | 
				
			||||||
    parser.cfg["token_vector_width"] = 8
 | 
					 | 
				
			||||||
    parser.cfg["hidden_width"] = 30
 | 
					 | 
				
			||||||
    parser.cfg["hist_size"] = 0
 | 
					 | 
				
			||||||
    parser.add_label("left")
 | 
					 | 
				
			||||||
    parser.begin_training([], **parser.cfg)
 | 
					 | 
				
			||||||
    sgd = Adam(NumpyOps(), 0.001)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    for i in range(10):
 | 
					 | 
				
			||||||
        losses = {}
 | 
					 | 
				
			||||||
        doc = Doc(vocab, words=["a", "b", "c", "d"])
 | 
					 | 
				
			||||||
        gold = GoldParse(doc, heads=[1, 1, 3, 3], deps=["left", "ROOT", "left", "ROOT"])
 | 
					 | 
				
			||||||
        parser.update([doc], [gold], sgd=sgd, losses=losses)
 | 
					 | 
				
			||||||
    return parser
 | 
					    return parser
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -38,10 +27,22 @@ def test_init_parser(parser):
 | 
				
			||||||
    pass
 | 
					    pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# TODO: This is flakey, because it depends on what the parser first learns.
 | 
					def _train_parser(parser):
 | 
				
			||||||
# TODO: This now seems to be implicated in segfaults. Not sure what's up!
 | 
					    fix_random_seed(1)
 | 
				
			||||||
@pytest.mark.skip
 | 
					    parser.add_label("left")
 | 
				
			||||||
 | 
					    parser.begin_training([], **parser.cfg)
 | 
				
			||||||
 | 
					    sgd = Adam(NumpyOps(), 0.001)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    for i in range(10):
 | 
				
			||||||
 | 
					        losses = {}
 | 
				
			||||||
 | 
					        doc = Doc(parser.vocab, words=["a", "b", "c", "d"])
 | 
				
			||||||
 | 
					        gold = GoldParse(doc, heads=[1, 1, 3, 3], deps=["left", "ROOT", "left", "ROOT"])
 | 
				
			||||||
 | 
					        parser.update([doc], [gold], sgd=sgd, losses=losses)
 | 
				
			||||||
 | 
					    return parser
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def test_add_label(parser):
 | 
					def test_add_label(parser):
 | 
				
			||||||
 | 
					    parser = _train_parser(parser)
 | 
				
			||||||
    doc = Doc(parser.vocab, words=["a", "b", "c", "d"])
 | 
					    doc = Doc(parser.vocab, words=["a", "b", "c", "d"])
 | 
				
			||||||
    doc = parser(doc)
 | 
					    doc = parser(doc)
 | 
				
			||||||
    assert doc[0].head.i == 1
 | 
					    assert doc[0].head.i == 1
 | 
				
			||||||
| 
						 | 
					@ -69,3 +70,16 @@ def test_add_label(parser):
 | 
				
			||||||
    doc = parser(doc)
 | 
					    doc = parser(doc)
 | 
				
			||||||
    assert doc[0].dep_ == "right"
 | 
					    assert doc[0].dep_ == "right"
 | 
				
			||||||
    assert doc[2].dep_ == "left"
 | 
					    assert doc[2].dep_ == "left"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.mark.xfail
 | 
				
			||||||
 | 
					def test_add_label_deserializes_correctly():
 | 
				
			||||||
 | 
					    ner1 = EntityRecognizer(Vocab())
 | 
				
			||||||
 | 
					    ner1.add_label("C")
 | 
				
			||||||
 | 
					    ner1.add_label("B")
 | 
				
			||||||
 | 
					    ner1.add_label("A")
 | 
				
			||||||
 | 
					    ner1.begin_training([])
 | 
				
			||||||
 | 
					    ner2 = EntityRecognizer(Vocab()).from_bytes(ner1.to_bytes())
 | 
				
			||||||
 | 
					    assert ner1.moves.n_moves == ner2.moves.n_moves
 | 
				
			||||||
 | 
					    for i in range(ner1.moves.n_moves):
 | 
				
			||||||
 | 
					        assert ner1.moves.get_class_name(i) == ner2.moves.get_class_name(i)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user