Intermediate state

This commit is contained in:
richard@explosion.ai 2022-11-03 15:23:50 +01:00
parent b462f85a73
commit 7db2770c05
2 changed files with 67 additions and 61 deletions

View File

@ -4,54 +4,59 @@ import pytest
@pytest.mark.parametrize("case_sensitive", [True, False]) @pytest.mark.parametrize("case_sensitive", [True, False])
def test_get_search_char_byte_arrays_1_width_only(case_sensitive): def test_get_search_char_byte_arrays_1_width_only(case_sensitive):
sc1, sc2, sc3, sc4 = spacy.util.get_search_char_byte_arrays("zzaaEP", case_sensitive) search_chars, width_offsets = spacy.util.get_search_char_byte_arrays(
"zzaaEP", case_sensitive
)
if case_sensitive: if case_sensitive:
assert sc1 == b"EPaz" assert search_chars == b"EPaz"
else: else:
assert sc1 == b"aepz" assert search_chars == b"aepz"
assert sc2 == b"" assert width_offsets == b"\x00\x04\x04\x04\x04"
assert sc3 == b""
assert sc4 == b""
@pytest.mark.parametrize("case_sensitive", [True, False]) @pytest.mark.parametrize("case_sensitive", [True, False])
def test_get_search_char_byte_arrays_4_width_only(case_sensitive): def test_get_search_char_byte_arrays_4_width_only(case_sensitive):
sc1, sc2, sc3, sc4 = spacy.util.get_search_char_byte_arrays("𐌞", case_sensitive) search_chars, width_offsets = spacy.util.get_search_char_byte_arrays(
assert sc1 == b"" "𐌞", case_sensitive
assert sc2 == b"" )
assert sc3 == b"" assert search_chars == "𐌞".encode("utf-8")
assert sc4 == "𐌞".encode("utf-8") assert width_offsets == b"\x00\x00\x00\x00\x04"
@pytest.mark.parametrize("case_sensitive", [True, False]) @pytest.mark.parametrize("case_sensitive", [True, False])
def test_get_search_char_byte_arrays_all_widths(case_sensitive): def test_get_search_char_byte_arrays_all_widths(case_sensitive):
sc1, sc2, sc3, sc4 = spacy.util.get_search_char_byte_arrays("𐌞Éabé—B𐌞", case_sensitive) search_chars, width_offsets = spacy.util.get_search_char_byte_arrays(
"𐌞Éabé—B𐌞", case_sensitive
)
if case_sensitive: if case_sensitive:
assert sc1 == b"Bab" assert search_chars == "BabÉé—𐌞".encode("utf-8")
assert sc2 == "Éé".encode("utf-8") assert width_offsets == b"\x00\x03\x07\x0a\x0e"
else: else:
assert sc1 == b"ab" assert search_chars == "abé—𐌞".encode("utf-8")
assert sc2 == "é".encode("utf-8") assert width_offsets == b"\x00\x02\x04\x07\x0b"
assert sc3 == "".encode("utf-8")
assert sc4 == "𐌞".encode("utf-8")
@pytest.mark.parametrize("case_sensitive", [True, False]) @pytest.mark.parametrize("case_sensitive", [True, False])
def test_turkish_i_with_dot(case_sensitive): def test_turkish_i_with_dot(case_sensitive):
sc1, sc2, sc3, sc4 = spacy.util.get_search_char_byte_arrays("İ", case_sensitive) search_chars, width_offsets = spacy.util.get_search_char_byte_arrays(
"İ", case_sensitive
)
if case_sensitive: if case_sensitive:
assert sc2 == "İ".encode("utf-8") assert search_chars == "İ".encode("utf-8")
assert sc1 == sc3 == sc4 == b"" assert width_offsets == b"\x00\x00\x02\x02\x02"
else: else:
assert sc1 == b"i" assert search_chars == b"i\xcc\x87"
assert sc2 == b"\xcc\x87" assert width_offsets == b"\x00\x01\x03\x03\x03"
assert sc3 == sc4 == b""
@pytest.mark.parametrize("case_sensitive", [True, False]) @pytest.mark.parametrize("case_sensitive", [True, False])
def test_turkish_i_with_dot_and_normal_i(case_sensitive): def test_turkish_i_with_dot_and_normal_i(case_sensitive):
sc1, sc2, sc3, sc4 = spacy.util.get_search_char_byte_arrays("İI", case_sensitive) search_chars, width_offsets = spacy.util.get_search_char_byte_arrays(
"İI", case_sensitive
)
if case_sensitive: if case_sensitive:
assert sc1 == b"I" assert search_chars == "".encode("utf-8")
assert sc2 == "İ".encode("utf-8") assert width_offsets == b"\x00\x01\x03\x03\x03"
assert sc3 == sc4 == b""
else: else:
assert sc1 == b"i" assert search_chars == b"i\xcc\x87"
assert sc2 == b"\xcc\x87" assert width_offsets == b"\x00\x01\x03\x03\x03"
assert sc3 == sc4 == b""

View File

@ -1,3 +1,4 @@
from turtle import width
from typing import List, Mapping, NoReturn, Union, Dict, Any, Set, cast from typing import List, Mapping, NoReturn, Union, Dict, Any, Set, cast
from typing import Optional, Iterable, Callable, Tuple, Type from typing import Optional, Iterable, Callable, Tuple, Type
from typing import Iterator, Pattern, Generator, TYPE_CHECKING from typing import Iterator, Pattern, Generator, TYPE_CHECKING
@ -1738,41 +1739,41 @@ def all_equal(iterable):
def get_search_char_byte_arrays( def get_search_char_byte_arrays(
search_chars: str, case_sensitive: bool search_char_string: str, case_sensitive: bool
) -> Tuple[bytes, bytes, bytes, bytes]: ) -> Tuple[bytes, bytes]:
""" """
This function supports the rich feature extractor. It splits the UTF-8 representation This function supports the rich feature extractor. It orders the characters in
of *search_chars* into separate byte arrays containing 1-, 2-, 3-, and 4-byte *search_char_string*, removing any duplicates, encodes them with UTF-8, and
characters respectively. Any duplicates in *search_chars* are removed, and *search_chars* returns the result together with a byte array containing the offsets where the
is converted to lower case if *case_sensitive==False*. characters of various byte lengths start within the result, i.e.
<1-byte-start>, <2-byte-start>, <3-byte-start>, <4-byte-start>, <4-byte-end>.
If the string does not contain any characters of length *n*,
<n_byte_start> == <n+1_byte_start>.
""" """
sc1 = bytearray()
sc2 = bytearray()
sc3 = bytearray()
sc4 = bytearray()
if not case_sensitive: if not case_sensitive:
search_chars = search_chars.lower() search_char_string = search_char_string.lower()
ordered_search_chars = "".join(sorted(set(search_chars))) ordered_search_char_string = "".join(sorted(set(search_char_string)))
encoded_search_char_bytes = ordered_search_chars.encode("UTF-8") search_chars = ordered_search_char_string.encode("UTF-8")
working_start = 0 width_offsets = [0, -1, -1, -1, -1]
for idx in range(len(encoded_search_char_bytes) + 1): working_start = -1
if idx == 0: working_width = 1
continue for idx in range(len(search_chars) + 1):
if ( if (
idx == len(encoded_search_char_bytes) idx == len(search_chars)
or encoded_search_char_bytes[idx] & 0xC0 != 0x80 # not continuation byte or search_chars[idx] & 0xC0 != 0x80 # not continuation byte
): ):
char_length = idx - working_start this_width = idx - working_start
if char_length == 1: if this_width > 4 or this_width < working_width:
sc1.extend(encoded_search_char_bytes[working_start:idx])
elif char_length == 2:
sc2.extend(encoded_search_char_bytes[working_start:idx])
elif char_length == 3:
sc3.extend(encoded_search_char_bytes[working_start:idx])
elif char_length == 4:
sc4.extend(encoded_search_char_bytes[working_start:idx])
else:
raise RuntimeError(Errors.E1050) raise RuntimeError(Errors.E1050)
if this_width > working_width:
width_offsets[this_width - 1] = working_start
working_width = this_width
working_start = idx working_start = idx
return bytes(sc1), bytes(sc2), bytes(sc3), bytes(sc4) width_offsets[this_width] = idx
for i in range(5):
if width_offsets[i] == -1:
width_offsets[i] = width_offsets[i - 1]
return search_chars, bytes((width_offsets))