From 5052ed8cadbf1b96cad7d61e5850e9735e0a6243 Mon Sep 17 00:00:00 2001 From: richardpaulhudson Date: Wed, 14 Sep 2022 15:32:08 +0200 Subject: [PATCH] Fix case sensitivity --- spacy/tests/doc/test_doc_api.py | 8 +++++++- spacy/tokens/doc.pyx | 1 + 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/spacy/tests/doc/test_doc_api.py b/spacy/tests/doc/test_doc_api.py index 9630c895e..fa895919a 100644 --- a/spacy/tests/doc/test_doc_api.py +++ b/spacy/tests/doc/test_doc_api.py @@ -983,7 +983,7 @@ def test_doc_spans_setdefault(en_tokenizer): def test_get_affixes_good_case(en_tokenizer, case_sensitive): doc = en_tokenizer("spaCy✨ and Prodigy") prefixes = doc.get_affixes(False, case_sensitive, 1, 5, "", 2, 3) - suffixes = doc.get_affixes(True, case_sensitive, 2, 6, "xx✨rp", 2, 3) + suffixes = doc.get_affixes(True, case_sensitive, 2, 6, "xx✨rP", 2, 3) assert prefixes[0][3, 3, 3] == suffixes[0][3, 3, 3] assert prefixes[0][3, 3, 2] == suffixes[0][3, 3, 4] assert (prefixes[0][0, :, 1:] == 0).all() @@ -994,10 +994,16 @@ def test_get_affixes_good_case(en_tokenizer, case_sensitive): assert not (suffixes[0][1, :, 2:] == 0).all() assert (suffixes[0][1, :, 3:] == 0).all() assert suffixes[1][0][1].tolist() == [10024, 0] + if case_sensitive: + assert suffixes[1][0][3].tolist() == [114, 80] + else: + assert suffixes[1][0][3].tolist() == [114, 112] + suffixes = doc.get_affixes(True, case_sensitive, 2, 6, "xx✨rp", 2, 3) if case_sensitive: assert suffixes[1][0][3].tolist() == [114, 0] else: assert suffixes[1][0][3].tolist() == [114, 112] + def test_get_affixes_4_byte_normal_char(en_tokenizer): diff --git a/spacy/tokens/doc.pyx b/spacy/tokens/doc.pyx index c732cd9ee..2aea6d9b4 100644 --- a/spacy/tokens/doc.pyx +++ b/spacy/tokens/doc.pyx @@ -1755,6 +1755,7 @@ cdef class Doc: token_attrs = [t.orth_ for t in self] else: token_attrs = [t.lower_ for t in self] + special_chars = special_chars.lower() cdef unsigned int sc_len = len(special_chars) cdef const unsigned char[:] sc_bytes = get_utf16_memoryview(special_chars, True) cdef np.ndarray[np.uint16_t, ndim=1] scs = numpy.ndarray((sc_len,), buffer=sc_bytes, dtype="uint16")