mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-11-04 01:48:04 +03:00 
			
		
		
		
	Reproducibility for TextCat and Tok2Vec (#6218)
* ensure fixed seed in HashEmbed layers * forgot about the joys of python 2
This commit is contained in:
		
							parent
							
								
									9fc8392b38
								
							
						
					
					
						commit
						2998131416
					
				| 
						 | 
				
			
			@ -654,10 +654,10 @@ def build_text_classifier(nr_class, width=64, **cfg):
 | 
			
		|||
            )
 | 
			
		||||
            return model
 | 
			
		||||
 | 
			
		||||
        lower = HashEmbed(width, nr_vector, column=1)
 | 
			
		||||
        prefix = HashEmbed(width // 2, nr_vector, column=2)
 | 
			
		||||
        suffix = HashEmbed(width // 2, nr_vector, column=3)
 | 
			
		||||
        shape = HashEmbed(width // 2, nr_vector, column=4)
 | 
			
		||||
        lower = HashEmbed(width, nr_vector, column=1, seed=10)
 | 
			
		||||
        prefix = HashEmbed(width // 2, nr_vector, column=2, seed=11)
 | 
			
		||||
        suffix = HashEmbed(width // 2, nr_vector, column=3, seed=12)
 | 
			
		||||
        shape = HashEmbed(width // 2, nr_vector, column=4, seed=13)
 | 
			
		||||
 | 
			
		||||
        trained_vectors = FeatureExtracter(
 | 
			
		||||
            [ORTH, LOWER, PREFIX, SUFFIX, SHAPE, ID]
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -27,16 +27,16 @@ def Tok2Vec(width, embed_size, **kwargs):
 | 
			
		|||
    bilstm_depth = kwargs.get("bilstm_depth", 0)
 | 
			
		||||
    cols = [ID, NORM, PREFIX, SUFFIX, SHAPE, ORTH]
 | 
			
		||||
    with Model.define_operators({">>": chain, "|": concatenate, "**": clone}):
 | 
			
		||||
        norm = HashEmbed(width, embed_size, column=cols.index(NORM), name="embed_norm")
 | 
			
		||||
        norm = HashEmbed(width, embed_size, column=cols.index(NORM), name="embed_norm", seed=6)
 | 
			
		||||
        if subword_features:
 | 
			
		||||
            prefix = HashEmbed(
 | 
			
		||||
                width, embed_size // 2, column=cols.index(PREFIX), name="embed_prefix"
 | 
			
		||||
                width, embed_size // 2, column=cols.index(PREFIX), name="embed_prefix", seed=7
 | 
			
		||||
            )
 | 
			
		||||
            suffix = HashEmbed(
 | 
			
		||||
                width, embed_size // 2, column=cols.index(SUFFIX), name="embed_suffix"
 | 
			
		||||
                width, embed_size // 2, column=cols.index(SUFFIX), name="embed_suffix", seed=8
 | 
			
		||||
            )
 | 
			
		||||
            shape = HashEmbed(
 | 
			
		||||
                width, embed_size // 2, column=cols.index(SHAPE), name="embed_shape"
 | 
			
		||||
                width, embed_size // 2, column=cols.index(SHAPE), name="embed_shape", seed=9
 | 
			
		||||
            )
 | 
			
		||||
        else:
 | 
			
		||||
            prefix, suffix, shape = (None, None, None)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -42,16 +42,16 @@ def MultiHashEmbed(config):
 | 
			
		|||
    width = config["width"]
 | 
			
		||||
    rows = config["rows"]
 | 
			
		||||
 | 
			
		||||
    norm = HashEmbed(width, rows, column=cols.index("NORM"), name="embed_norm")
 | 
			
		||||
    norm = HashEmbed(width, rows, column=cols.index("NORM"), name="embed_norm", seed=1)
 | 
			
		||||
    if config["use_subwords"]:
 | 
			
		||||
        prefix = HashEmbed(
 | 
			
		||||
            width, rows // 2, column=cols.index("PREFIX"), name="embed_prefix"
 | 
			
		||||
            width, rows // 2, column=cols.index("PREFIX"), name="embed_prefix", seed=2
 | 
			
		||||
        )
 | 
			
		||||
        suffix = HashEmbed(
 | 
			
		||||
            width, rows // 2, column=cols.index("SUFFIX"), name="embed_suffix"
 | 
			
		||||
            width, rows // 2, column=cols.index("SUFFIX"), name="embed_suffix", seed=3
 | 
			
		||||
        )
 | 
			
		||||
        shape = HashEmbed(
 | 
			
		||||
            width, rows // 2, column=cols.index("SHAPE"), name="embed_shape"
 | 
			
		||||
            width, rows // 2, column=cols.index("SHAPE"), name="embed_shape", seed=4
 | 
			
		||||
        )
 | 
			
		||||
    if config.get("@pretrained_vectors"):
 | 
			
		||||
        glove = make_layer(config["@pretrained_vectors"])
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										35
									
								
								spacy/tests/regression/test_issue6177.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										35
									
								
								spacy/tests/regression/test_issue6177.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,35 @@
 | 
			
		|||
# coding: utf8
 | 
			
		||||
from __future__ import unicode_literals
 | 
			
		||||
 | 
			
		||||
from spacy.lang.en import English
 | 
			
		||||
from spacy.util import fix_random_seed
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_issue6177():
    """Test that after fixing the random seed, the results of the pipeline are truly identical.

    Trains a fresh ``textcat`` pipe three times from the same seed and asserts
    that the model's predictions are bitwise-identical across runs, which only
    holds if every layer (including the HashEmbed tables) is seeded
    deterministically.
    """
    # NOTE: no need to transform this code to v3 when 'master' is merged into 'develop'.
    # A similar test exists already for v3: test_issue5551
    # This is just a backport
    results = []
    for i in range(3):
        # Reseed everything before each run so all three runs start identically
        fix_random_seed(0)
        nlp = English()
        example = (
            "Once hot, form ping-pong-ball-sized balls of the mixture, each weighing roughly 25 g.",
            # Fixed: "Labe1" was an evident typo, inconsistent with Label2/Label3
            {"cats": {"Label1": 1.0, "Label2": 0.0, "Label3": 0.0}},
        )
        textcat = nlp.create_pipe("textcat")
        nlp.add_pipe(textcat)
        for label in set(example[1]["cats"]):
            textcat.add_label(label)
        nlp.begin_training()
        # Store the result of each iteration
        result = textcat.model.predict([nlp.make_doc(example[0])])
        results.append(list(result[0]))

    # All results should be the same because of the fixed seed
    assert len(results) == 3
    assert results[0] == results[1]
    assert results[0] == results[2]
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user