Mirror of https://github.com/explosion/spaCy.git, synced 2025-11-01 00:17:44 +03:00
* Draft out initial Spans data structure
* Initial span group commit
* Basic span group support on Doc
* Basic test for span group
* Compile span_group.pyx
* Draft addition of SpanGroup to DocBin
* Add deserialization for SpanGroup
* Add tests for serializing SpanGroup
* Fix serialization of SpanGroup
* Add EdgeC and GraphC structs
* Add draft Graph data structure
* Compile graph
* More work on Graph
* Update GraphC
* Upd graph
* Fix walk functions
* Let Graph take nodes and edges on construction
* Fix walking and getting
* Add graph tests
* Fix import
* Add module with the SpanGroups dict thingy
* Update test
* Rename 'span_groups' attribute
* Try to fix c++11 compilation
* Fix test
* Update DocBin
* Try to fix compilation
* Try to fix graph
* Improve SpanGroup docstrings
* Add doc.spans to documentation
* Fix serialization
* Tidy up and add docs
* Update docs [ci skip]
* Add SpanGroup.has_overlap
* WIP updated Graph API
* Start testing new Graph API
* Update Graph tests
* Update Graph
* Add docstring

Co-authored-by: Ines Montani <ines@ines.io>
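For context, the test file below exercises the doc.spans API that these commits introduce. Here is a minimal sketch of how a span group behaves, assuming a spaCy v3 build with the SpanGroup changes applied (has_overlap comes from the commit log above; the roundtrip mirrors the tests below):

import spacy
from spacy.tokens import Doc

nlp = spacy.blank("en")
# Build a doc and store overlapping spans under a named group
doc = Doc(nlp.vocab, words=["hello", "world", "!"])
doc.spans["content"] = [doc[0:2], doc[1:3]]
assert len(doc.spans["content"]) == 2
assert doc.spans["content"].has_overlap  # both spans include token 1
# Span groups survive a bytes roundtrip, as the tests verify
new_doc = Doc(nlp.vocab).from_bytes(doc.to_bytes())
assert len(new_doc.spans["content"]) == 2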
		
			
				
	
	
		
124 lines · 4.0 KiB · Python
import pytest
from spacy.tokens.doc import Underscore

import spacy
from spacy.lang.en import English
from spacy.tokens import Doc, DocBin

from ..util import make_tempdir


def test_serialize_empty_doc(en_vocab):
    doc = Doc(en_vocab)
    data = doc.to_bytes()
    doc2 = Doc(en_vocab)
    doc2.from_bytes(data)
    assert len(doc) == len(doc2)
    for token1, token2 in zip(doc, doc2):
        assert token1.text == token2.text


def test_serialize_doc_roundtrip_bytes(en_vocab):
    doc = Doc(en_vocab, words=["hello", "world"])
    doc.cats = {"A": 0.5}
    doc_b = doc.to_bytes()
    new_doc = Doc(en_vocab).from_bytes(doc_b)
    assert new_doc.to_bytes() == doc_b


def test_serialize_doc_roundtrip_disk(en_vocab):
    doc = Doc(en_vocab, words=["hello", "world"])
    with make_tempdir() as d:
        file_path = d / "doc"
        doc.to_disk(file_path)
        doc_d = Doc(en_vocab).from_disk(file_path)
        assert doc.to_bytes() == doc_d.to_bytes()


def test_serialize_doc_roundtrip_disk_str_path(en_vocab):
    doc = Doc(en_vocab, words=["hello", "world"])
    with make_tempdir() as d:
        file_path = d / "doc"
        file_path = str(file_path)
        doc.to_disk(file_path)
        doc_d = Doc(en_vocab).from_disk(file_path)
        assert doc.to_bytes() == doc_d.to_bytes()


def test_serialize_doc_exclude(en_vocab):
    doc = Doc(en_vocab, words=["hello", "world"])
    doc.user_data["foo"] = "bar"
    new_doc = Doc(en_vocab).from_bytes(doc.to_bytes())
    assert new_doc.user_data["foo"] == "bar"
    # user_data can be dropped on either deserialization or serialization
    new_doc = Doc(en_vocab).from_bytes(doc.to_bytes(), exclude=["user_data"])
    assert not new_doc.user_data
    new_doc = Doc(en_vocab).from_bytes(doc.to_bytes(exclude=["user_data"]))
    assert not new_doc.user_data


def test_serialize_doc_span_groups(en_vocab):
    doc = Doc(en_vocab, words=["hello", "world", "!"])
    doc.spans["content"] = [doc[0:2]]
    new_doc = Doc(en_vocab).from_bytes(doc.to_bytes())
    assert len(new_doc.spans["content"]) == 1


def test_serialize_doc_bin():
    doc_bin = DocBin(attrs=["LEMMA", "ENT_IOB", "ENT_TYPE"], store_user_data=True)
    texts = ["Some text", "Lots of texts...", "..."]
    cats = {"A": 0.5}
    nlp = English()
    for doc in nlp.pipe(texts):
        doc.cats = cats
        doc.spans["start"] = [doc[0:2]]
        doc_bin.add(doc)
    bytes_data = doc_bin.to_bytes()

    # Deserialize later, e.g. in a new process
    nlp = spacy.blank("en")
    doc_bin = DocBin().from_bytes(bytes_data)
    reloaded_docs = list(doc_bin.get_docs(nlp.vocab))
    for i, doc in enumerate(reloaded_docs):
        assert doc.text == texts[i]
        assert doc.cats == cats
        assert len(doc.spans) == 1


def test_serialize_doc_bin_unknown_spaces(en_vocab):
    # Without explicit spaces, a space is assumed after every token
    # (note the trailing space in the text)
    doc1 = Doc(en_vocab, words=["that", "'s"])
    assert doc1.has_unknown_spaces
    assert doc1.text == "that 's "
    doc2 = Doc(en_vocab, words=["that", "'s"], spaces=[False, False])
    assert not doc2.has_unknown_spaces
    assert doc2.text == "that's"

    doc_bin = DocBin().from_bytes(DocBin(docs=[doc1, doc2]).to_bytes())
    re_doc1, re_doc2 = doc_bin.get_docs(en_vocab)
    assert re_doc1.has_unknown_spaces
    assert re_doc1.text == "that 's "
    assert not re_doc2.has_unknown_spaces
    assert re_doc2.text == "that's"


@pytest.mark.parametrize(
    "writer_flag,reader_flag,reader_value",
    [
        (True, True, "bar"),
        (True, False, "bar"),
        (False, True, "nothing"),
        (False, False, "nothing"),
    ],
)
def test_serialize_custom_extension(en_vocab, writer_flag, reader_flag, reader_value):
    """Test that custom extensions are correctly serialized in DocBin."""
    Doc.set_extension("foo", default="nothing")
    doc = Doc(en_vocab, words=["hello", "world"])
    doc._.foo = "bar"
    doc_bin_1 = DocBin(store_user_data=writer_flag)
    doc_bin_1.add(doc)
    doc_bin_bytes = doc_bin_1.to_bytes()
    doc_bin_2 = DocBin(store_user_data=reader_flag).from_bytes(doc_bin_bytes)
    doc_2 = list(doc_bin_2.get_docs(en_vocab))[0]
    assert doc_2._.foo == reader_value
    # Remove the custom extension so it doesn't leak into other tests
    Underscore.doc_extensions = {}