mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-11-04 09:57:26 +03:00 
			
		
		
		
	Add serialization tests for tagger
This commit is contained in:
		
							parent
							
								
									1b593bbd6d
								
							
						
					
					
						commit
						43b4d63f85
					
				
							
								
								
									
										39
									
								
								spacy/tests/serialize/test_serialize_tagger.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										39
									
								
								spacy/tests/serialize/test_serialize_tagger.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,39 @@
 | 
			
		|||
# coding: utf-8
 | 
			
		||||
from __future__ import unicode_literals
 | 
			
		||||
 | 
			
		||||
from ..util import make_tempdir
 | 
			
		||||
from ...pipeline import NeuralTagger as Tagger
 | 
			
		||||
 | 
			
		||||
import pytest
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.fixture
 | 
			
		||||
def taggers(en_vocab):
 | 
			
		||||
    tagger1 = Tagger(en_vocab, True)
 | 
			
		||||
    tagger2 = Tagger(en_vocab, True)
 | 
			
		||||
    tagger1.model = tagger1.Model(None, None)
 | 
			
		||||
    tagger2.model = tagger2.Model(None, None)
 | 
			
		||||
    return (tagger1, tagger2)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_serialize_tagger_roundtrip_bytes(en_vocab, taggers):
 | 
			
		||||
    tagger1, tagger2 = taggers
 | 
			
		||||
    tagger1_b = tagger1.to_bytes()
 | 
			
		||||
    tagger2_b = tagger2.to_bytes()
 | 
			
		||||
    assert tagger1_b == tagger2_b
 | 
			
		||||
    tagger1 = tagger1.from_bytes(tagger1_b)
 | 
			
		||||
    assert tagger1.to_bytes() == tagger1_b
 | 
			
		||||
    new_tagger1 = Tagger(en_vocab).from_bytes(tagger1_b)
 | 
			
		||||
    assert new_tagger1.to_bytes() == tagger1_b
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_serialize_tagger_roundtrip_disk(en_vocab, taggers):
 | 
			
		||||
    tagger1, tagger2 = taggers
 | 
			
		||||
    with make_tempdir() as d:
 | 
			
		||||
        file_path1 = d / 'tagger1'
 | 
			
		||||
        file_path2 = d / 'tagger2'
 | 
			
		||||
        tagger1.to_disk(file_path1)
 | 
			
		||||
        tagger2.to_disk(file_path2)
 | 
			
		||||
        tagger1_d = Tagger(en_vocab).from_disk(file_path1)
 | 
			
		||||
        tagger2_d = Tagger(en_vocab).from_disk(file_path2)
 | 
			
		||||
        assert tagger1_d.to_bytes() == tagger2_d.to_bytes()
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user