spaCy/spacy/tests/serialize/test_serialize_tokenizer.py
2023-06-26 11:41:03 +02:00

150 lines
5.3 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import pickle
import re
import pytest
from spacy.attrs import ENT_IOB, ENT_TYPE
from spacy.lang.en import English
from spacy.tokenizer import Tokenizer
from spacy.tokens import Doc
from spacy.util import (
compile_infix_regex,
compile_prefix_regex,
compile_suffix_regex,
get_lang_class,
load_model,
)
from ..util import assert_packed_msg_equal, make_tempdir
def load_tokenizer(b):
    """Create a new English tokenizer and load serialized tokenizer data into it.

    b: Serialized tokenizer data, handed unchanged to ``from_bytes``.
    RETURNS: The newly created tokenizer after ``from_bytes`` was called on it.
    """
    english_cls = get_lang_class("en")
    fresh_tokenizer = english_cls().tokenizer
    fresh_tokenizer.from_bytes(b)
    return fresh_tokenizer
@pytest.mark.issue(2833)
def test_issue2833(en_vocab):
    """Test that pickling a single token or a span raises NotImplementedError."""
    doc = Doc(en_vocab, words=["Hello", "world"])
    # An int index selects one token, the slice selects a two-token span.
    for index in (0, slice(0, 2)):
        with pytest.raises(NotImplementedError):
            pickle.dumps(doc[index])
@pytest.mark.issue(3012)
def test_issue3012(en_vocab):
    """Test that tag/POS annotation is not overwritten when from_array is
    called with a header that carries no tag information, and that it also
    survives a to_bytes/from_bytes roundtrip."""

    def third_token_summary(d):
        # Text, coarse POS, fine-grained tag and entity type of token 2.
        tok = d[2]
        return (tok.text, tok.pos_, tok.tag_, tok.ent_type_)

    doc = Doc(
        en_vocab,
        words=["This", "is", "10", "%", "."],
        tags=["DT", "VBZ", "CD", "NN", "."],
        pos=["DET", "VERB", "NUM", "NOUN", "PUNCT"],
        ents=["O", "O", "B-PERCENT", "I-PERCENT", "O"],
    )
    assert doc.has_annotation("TAG")
    expected = ("10", "NUM", "CD", "PERCENT")
    assert third_token_summary(doc) == expected

    # Export only the entity columns and load them straight back in
    entity_columns = [ENT_IOB, ENT_TYPE]
    doc.from_array(entity_columns, doc.to_array(entity_columns))
    assert third_token_summary(doc) == expected

    # Serializing then deserializing
    restored = Doc(en_vocab).from_bytes(doc.to_bytes())
    assert third_token_summary(restored) == expected
@pytest.mark.issue(4190)
def test_issue4190():
    """Test that a customized tokenizer (including faster_heuristics=False)
    produces the same tokens after the pipeline is saved to disk and loaded
    again."""

    def customize_tokenizer(nlp):
        """Replace nlp.tokenizer with one built from the language defaults,
        minus the two-character exceptions that end in a period."""
        defaults = nlp.Defaults
        prefix_re = compile_prefix_regex(defaults.prefixes)
        suffix_re = compile_suffix_regex(defaults.suffixes)
        infix_re = compile_infix_regex(defaults.infixes)
        # Remove all exceptions where a single letter is followed by a period (e.g. 'h.')
        kept_exceptions = {}
        for string, substrings in dict(defaults.tokenizer_exceptions).items():
            if len(string) == 2 and string[1] == ".":
                continue
            kept_exceptions[string] = substrings
        nlp.tokenizer = Tokenizer(
            nlp.vocab,
            kept_exceptions,
            prefix_search=prefix_re.search,
            suffix_search=suffix_re.search,
            infix_finditer=infix_re.finditer,
            token_match=nlp.tokenizer.token_match,
            faster_heuristics=False,
        )

    text = "Test c."
    # Load default language and run it once before touching the tokenizer
    nlp_1 = English()
    default_doc = nlp_1(text)
    default_tokens = [token.text for token in default_doc]  # noqa: F841
    # Modify tokenizer
    customize_tokenizer(nlp_1)
    customized_tokens = [token.text for token in nlp_1(text)]
    # Save and Reload
    with make_tempdir() as model_dir:
        nlp_1.to_disk(model_dir)
        nlp_2 = load_model(model_dir)
    # This should be the modified tokenizer
    reloaded_tokens = [token.text for token in nlp_2(text)]
    assert customized_tokens == reloaded_tokens
    assert nlp_2.tokenizer.faster_heuristics is False
def test_serialize_custom_tokenizer(en_vocab, en_tokenizer):
    """Test that a custom tokenizer with only some functions defined, or with
    empty properties, can be serialized and deserialized correctly (see
    #2494, #4991)."""
    # A tokenizer that only defines suffix_search must roundtrip via bytes
    partial_tokenizer = Tokenizer(en_vocab, suffix_search=en_tokenizer.suffix_search)
    partial_bytes = partial_tokenizer.to_bytes()
    Tokenizer(en_vocab).from_bytes(partial_bytes)

    # test that empty/unset values are set correctly on deserialization
    full_tokenizer = get_lang_class("en")().tokenizer
    full_tokenizer.token_match = re.compile("test").match
    unset_in_partial = ("token_match", "url_match", "prefix_search", "infix_finditer")
    assert full_tokenizer.rules != {}
    for attr_name in unset_in_partial:
        assert getattr(full_tokenizer, attr_name) is not None
    full_tokenizer.from_bytes(partial_bytes)
    assert full_tokenizer.rules == {}
    for attr_name in unset_in_partial:
        assert getattr(full_tokenizer, attr_name) is None

    # Rules that were emptied after construction must deserialize as empty
    emptied_tokenizer = Tokenizer(
        en_vocab, rules={"ABC.": [{"ORTH": "ABC"}, {"ORTH": "."}]}
    )
    emptied_tokenizer.rules = {}
    reloaded = Tokenizer(en_vocab).from_bytes(emptied_tokenizer.to_bytes())
    assert reloaded.rules == {}
@pytest.mark.parametrize("text", ["I💜you", "theyre", "“hello”"])
def test_serialize_tokenizer_roundtrip_bytes(en_tokenizer, text):
    """Test that a tokenizer restored from its own bytes serializes to the
    same data and splits the given text into the same token texts."""
    restored = load_tokenizer(en_tokenizer.to_bytes())
    assert_packed_msg_equal(restored.to_bytes(), en_tokenizer.to_bytes())
    assert restored.to_bytes() == en_tokenizer.to_bytes()
    original_texts = [token.text for token in en_tokenizer(text)]
    restored_texts = [token.text for token in restored(text)]
    assert original_texts == restored_texts
def test_serialize_tokenizer_roundtrip_disk(en_tokenizer):
    """Test that a tokenizer saved to disk can be restored from disk.

    The serialized state is captured before saving, and the saved file is
    loaded into a separate, newly created English tokenizer. Loading back
    into ``en_tokenizer`` itself would make the check vacuous: ``from_disk``
    modifies the tokenizer it is called on and returns that same object, so
    both sides of the comparison would come from one instance.
    """
    expected_bytes = en_tokenizer.to_bytes()
    with make_tempdir() as d:
        file_path = d / "tokenizer"
        en_tokenizer.to_disk(file_path)
        # Fresh instance, built the same way load_tokenizer() builds one
        reloaded = get_lang_class("en")().tokenizer
        reloaded.from_disk(file_path)
        assert reloaded.to_bytes() == expected_bytes