Intermediate state

This commit is contained in:
richard@explosion.ai 2022-11-03 15:23:50 +01:00
parent b462f85a73
commit 7db2770c05
2 changed files with 67 additions and 61 deletions

View File

@@ -4,54 +4,59 @@ import pytest
@pytest.mark.parametrize("case_sensitive", [True, False])
def test_get_search_char_byte_arrays_1_width_only(case_sensitive):
    # All input characters are ASCII (1-byte UTF-8), so every width offset
    # after the 1-byte group points at the end of the 4-byte buffer.
    search_chars, width_offsets = spacy.util.get_search_char_byte_arrays(
        "zzaaEP", case_sensitive
    )
    if case_sensitive:
        assert search_chars == b"EPaz"
    else:
        assert search_chars == b"aepz"
    # Four 1-byte chars in both cases: 1-byte group spans [0, 4); the 2-, 3-
    # and 4-byte groups are all empty at offset 4.
    assert width_offsets == b"\x00\x04\x04\x04\x04"
@pytest.mark.parametrize("case_sensitive", [True, False])
def test_get_search_char_byte_arrays_4_width_only(case_sensitive):
    # A single supplementary-plane character (U+1031E) encodes to 4 UTF-8
    # bytes; it has no case mapping, so both parametrizations agree.
    search_chars, width_offsets = spacy.util.get_search_char_byte_arrays(
        "𐌞", case_sensitive
    )
    assert search_chars == "𐌞".encode("utf-8")
    # 1-, 2- and 3-byte groups are empty at offset 0; the 4-byte group
    # spans [0, 4).
    assert width_offsets == b"\x00\x00\x00\x00\x04"
@pytest.mark.parametrize("case_sensitive", [True, False])
def test_get_search_char_byte_arrays_all_widths(case_sensitive):
    # Mixes 1-byte (a, b, B), 2-byte (É, é), 3-byte (—) and 4-byte (𐌞)
    # characters, with a duplicate 𐌞 that must be removed.
    search_chars, width_offsets = spacy.util.get_search_char_byte_arrays(
        "𐌞Éabé—B𐌞", case_sensitive
    )
    if case_sensitive:
        # Sorted by code point: B a b | É é | — | 𐌞
        assert search_chars == "BabÉé—𐌞".encode("utf-8")
        # Group starts: 1-byte at 0, 2-byte at 3, 3-byte at 7, 4-byte at 10,
        # end at 14.
        assert width_offsets == b"\x00\x03\x07\x0a\x0e"
    else:
        # Lower-casing merges B/b and É/é: a b | é | — | 𐌞
        assert search_chars == "abé—𐌞".encode("utf-8")
        assert width_offsets == b"\x00\x02\x04\x07\x0b"
@pytest.mark.parametrize("case_sensitive", [True, False])
def test_turkish_i_with_dot(case_sensitive):
    # U+0130 (LATIN CAPITAL LETTER I WITH DOT ABOVE) is the one upper-case
    # character whose str.lower() expands to TWO code points:
    # "i" + U+0307 (COMBINING DOT ABOVE).
    search_chars, width_offsets = spacy.util.get_search_char_byte_arrays(
        "İ", case_sensitive
    )
    if case_sensitive:
        # İ alone: a single 2-byte character at [0, 2).
        assert search_chars == "İ".encode("utf-8")
        assert width_offsets == b"\x00\x00\x02\x02\x02"
    else:
        # Lowered to "i" (1 byte) plus the combining dot (2 bytes, \xcc\x87).
        assert search_chars == b"i\xcc\x87"
        assert width_offsets == b"\x00\x01\x03\x03\x03"
@pytest.mark.parametrize("case_sensitive", [True, False])
def test_turkish_i_with_dot_and_normal_i(case_sensitive):
    # İ (U+0130) together with plain I: lower-casing collapses both onto "i"
    # (plus the combining dot from İ), so the insensitive result is the same
    # as for "İ" alone.
    search_chars, width_offsets = spacy.util.get_search_char_byte_arrays(
        "İI", case_sensitive
    )
    if case_sensitive:
        # NOTE(review): the expected literal was lost in rendering; "Iİ" is
        # the only value consistent with width_offsets below (one 1-byte
        # char at [0, 1), one 2-byte char at [1, 3)) — confirm upstream.
        assert search_chars == "Iİ".encode("utf-8")
        assert width_offsets == b"\x00\x01\x03\x03\x03"
    else:
        assert search_chars == b"i\xcc\x87"
        assert width_offsets == b"\x00\x01\x03\x03\x03"

View File

@@ -1,3 +1,4 @@
from turtle import width
from typing import List, Mapping, NoReturn, Union, Dict, Any, Set, cast
from typing import Optional, Iterable, Callable, Tuple, Type
from typing import Iterator, Pattern, Generator, TYPE_CHECKING
@@ -1738,41 +1739,41 @@ def all_equal(iterable):
def get_search_char_byte_arrays(
    search_char_string: str, case_sensitive: bool
) -> Tuple[bytes, bytes]:
    """
    This function supports the rich feature extractor. It orders the characters in
    *search_char_string*, removing any duplicates, encodes them with UTF-8, and
    returns the result together with a byte array containing the offsets where the
    characters of various byte lengths start within the result, i.e.
    <1-byte-start>, <2-byte-start>, <3-byte-start>, <4-byte-start>, <4-byte-end>.
    If the string does not contain any characters of length *n*,
    <n_byte_start> == <n+1_byte_start>.
    """
    if not case_sensitive:
        search_char_string = search_char_string.lower()
    # Sorting the deduplicated characters by code point also sorts their UTF-8
    # encodings lexicographically, which groups them by encoded byte length:
    # 1-byte characters first, then 2-, 3- and 4-byte characters.
    ordered_search_char_string = "".join(sorted(set(search_char_string)))
    search_chars = ordered_search_char_string.encode("UTF-8")
    width_offsets = [0, -1, -1, -1, -1]  # -1 == "width not seen yet"
    # Start at -1 so the very first lead byte computes a dummy width of 1,
    # which is a no-op against working_width == 1.
    working_start = -1
    working_width = 1
    for idx in range(len(search_chars) + 1):
        if (
            idx == len(search_chars)
            or search_chars[idx] & 0xC0 != 0x80  # not a continuation byte
        ):
            # *idx* starts a new character (or is the end of the buffer), so
            # the previous character occupied [working_start, idx).
            this_width = idx - working_start
            if this_width > 4 or this_width < working_width:
                # Valid UTF-8 never exceeds 4 bytes per character, and the
                # code-point sort guarantees non-decreasing widths.
                raise RuntimeError(Errors.E1050)
            if this_width > working_width:
                # First character of a wider group: record where it starts.
                width_offsets[this_width - 1] = working_start
                working_width = this_width
            working_start = idx
    # The final loop iteration (idx == len) always bound *this_width*; record
    # the end offset of the widest group present.
    width_offsets[this_width] = idx
    # Widths that never occurred inherit the next-narrower offset so that
    # <n_byte_start> == <n+1_byte_start> for absent widths.
    for i in range(5):
        if width_offsets[i] == -1:
            width_offsets[i] = width_offsets[i - 1]
    return search_chars, bytes(width_offsets)