mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-11-04 01:48:04 +03:00 
			
		
		
		
	
		
			
				
	
	
		
			114 lines
		
	
	
		
			3.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			114 lines
		
	
	
		
			3.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
import pytest
 | 
						|
 | 
						|
import spacy
 | 
						|
from spacy.lang.en import English
 | 
						|
from spacy.tokens import Doc, DocBin
 | 
						|
from spacy.tokens.underscore import Underscore
 | 
						|
 | 
						|
 | 
						|
@pytest.mark.issue(4367)
def test_issue4367():
    """Constructing a DocBin with default or explicit attrs must not raise."""
    DocBin()
    # Explicit attribute lists: a single attr, then several entity-related ones.
    for attr_names in (["LEMMA"], ["LEMMA", "ENT_IOB", "ENT_TYPE"]):
        DocBin(attrs=attr_names)
 | 
						|
 | 
						|
 | 
						|
@pytest.mark.issue(4528)
def test_issue4528(en_vocab):
    """Plain and extension-attribute user_data entries survive a DocBin round trip."""
    # This is how extension attribute values are stored in the user data
    ext_key = ("._.", "foo", None, None)
    original = Doc(en_vocab, words=["hello", "world"])
    original.user_data["foo"] = "bar"
    original.user_data[ext_key] = "bar"

    writer = DocBin(store_user_data=True)
    writer.add(original)
    payload = writer.to_bytes()

    reader = DocBin(store_user_data=True).from_bytes(payload)
    restored_docs = list(reader.get_docs(en_vocab))
    restored = restored_docs[0]
    assert restored.user_data["foo"] == "bar"
    assert restored.user_data[ext_key] == "bar"
 | 
						|
 | 
						|
 | 
						|
@pytest.mark.issue(5141)
def test_issue5141(en_vocab):
    """An empty DocBin yields no docs before and after a bytes round trip."""
    empty_bin = DocBin(attrs=["DEP", "HEAD"])
    assert list(empty_bin.get_docs(en_vocab)) == []

    # Serializing and reloading an empty container must not crash.
    restored_bin = DocBin().from_bytes(empty_bin.to_bytes())
    assert list(restored_bin.get_docs(en_vocab)) == []
 | 
						|
 | 
						|
 | 
						|
def test_serialize_doc_bin():
    """Round-trip docs through DocBin and check that text, cats, span
    label/id/kb_id and token norm/ent_id values are all restored."""
    doc_bin = DocBin(
        attrs=["LEMMA", "ENT_IOB", "ENT_TYPE", "NORM", "ENT_ID"], store_user_data=True
    )
    texts = ["Some text", "Lots of texts...", "..."]
    cats = {"A": 0.5}
    nlp = English()
    for doc in nlp.pipe(texts):
        doc.cats = cats
        span = doc[0:2]
        span.label_ = "UNUSUAL_SPAN_LABEL"
        span.id_ = "UNUSUAL_SPAN_ID"
        span.kb_id_ = "UNUSUAL_SPAN_KB_ID"
        doc.spans["start"] = [span]
        doc[0].norm_ = "UNUSUAL_TOKEN_NORM"
        doc[0].ent_id_ = "UNUSUAL_TOKEN_ENT_ID"
        doc_bin.add(doc)
    bytes_data = doc_bin.to_bytes()

    # Deserialize later, e.g. in a new process
    nlp = spacy.blank("en")
    doc_bin = DocBin().from_bytes(bytes_data)
    reloaded_docs = list(doc_bin.get_docs(nlp.vocab))
    # Without this check an empty result would skip every assertion in the
    # loop below and the test would pass vacuously.
    assert len(reloaded_docs) == len(texts)
    for text, doc in zip(texts, reloaded_docs):
        assert doc.text == text
        assert doc.cats == cats
        assert len(doc.spans) == 1
        assert doc.spans["start"][0].label_ == "UNUSUAL_SPAN_LABEL"
        assert doc.spans["start"][0].id_ == "UNUSUAL_SPAN_ID"
        assert doc.spans["start"][0].kb_id_ == "UNUSUAL_SPAN_KB_ID"
        assert doc[0].norm_ == "UNUSUAL_TOKEN_NORM"
        assert doc[0].ent_id_ == "UNUSUAL_TOKEN_ENT_ID"
 | 
						|
 | 
						|
 | 
						|
def test_serialize_doc_bin_unknown_spaces(en_vocab):
    """The has_unknown_spaces flag and the resulting text survive DocBin."""
    words = ["that", "'s"]
    unknown = Doc(en_vocab, words=list(words))
    known = Doc(en_vocab, words=list(words), spaces=[False, False])
    # (expected has_unknown_spaces, expected text) for (unknown, known)
    expected = [(True, "that 's "), (False, "that's")]

    for doc, (unknown_spaces, text) in zip((unknown, known), expected):
        assert bool(doc.has_unknown_spaces) == unknown_spaces
        assert doc.text == text

    restored_bin = DocBin().from_bytes(DocBin(docs=[unknown, known]).to_bytes())
    # Tuple unpacking also enforces that exactly two docs come back.
    re_unknown, re_known = restored_bin.get_docs(en_vocab)

    for doc, (unknown_spaces, text) in zip((re_unknown, re_known), expected):
        assert bool(doc.has_unknown_spaces) == unknown_spaces
        assert doc.text == text
 | 
						|
 | 
						|
 | 
						|
@pytest.mark.parametrize(
    "writer_flag,reader_flag,reader_value",
    [
        (True, True, "bar"),
        (True, False, "bar"),
        (False, True, "nothing"),
        (False, False, "nothing"),
    ],
)
def test_serialize_custom_extension(en_vocab, writer_flag, reader_flag, reader_value):
    """Test that custom extensions are correctly serialized in DocBin.

    The extension value set on the source doc should be restored only when
    the writing DocBin stores user data (``writer_flag``); otherwise the
    extension's default ("nothing") is seen after reloading.
    """
    Doc.set_extension("foo", default="nothing")
    try:
        doc = Doc(en_vocab, words=["hello", "world"])
        doc._.foo = "bar"
        doc_bin_1 = DocBin(store_user_data=writer_flag)
        doc_bin_1.add(doc)
        doc_bin_bytes = doc_bin_1.to_bytes()
        doc_bin_2 = DocBin(store_user_data=reader_flag).from_bytes(doc_bin_bytes)
        doc_2 = list(doc_bin_2.get_docs(en_vocab))[0]
        assert doc_2._.foo == reader_value
    finally:
        # Reset the global extension registry even when the assertion fails;
        # otherwise "foo" stays registered and the next parametrization's
        # set_extension call (without force=True) fails as well.
        Underscore.doc_extensions = {}
 |