mirror of
https://github.com/explosion/spaCy.git
synced 2025-08-02 11:20:19 +03:00
Intermediate state
This commit is contained in:
parent
b462f85a73
commit
7db2770c05
|
@ -4,54 +4,59 @@ import pytest
|
|||
|
||||
@pytest.mark.parametrize("case_sensitive", [True, False])
def test_get_search_char_byte_arrays_1_width_only(case_sensitive):
    """Search characters that are all one byte wide in UTF-8."""
    search_chars, width_offsets = spacy.util.get_search_char_byte_arrays(
        "zzaaEP", case_sensitive
    )
    if case_sensitive:
        assert search_chars == b"EPaz"
    else:
        assert search_chars == b"aepz"
    # Four 1-byte characters in both cases, so every later width group is
    # empty and starts where the 1-byte group ends.
    assert width_offsets == b"\x00\x04\x04\x04\x04"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("case_sensitive", [True, False])
def test_get_search_char_byte_arrays_4_width_only(case_sensitive):
    """A single search character that is four bytes wide in UTF-8."""
    search_chars, width_offsets = spacy.util.get_search_char_byte_arrays(
        "𐌞", case_sensitive
    )
    assert search_chars == "𐌞".encode("utf-8")
    # The 1-, 2- and 3-byte groups are empty; the 4-byte group spans bytes 0-4.
    assert width_offsets == b"\x00\x00\x00\x00\x04"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("case_sensitive", [True, False])
def test_get_search_char_byte_arrays_all_widths(case_sensitive):
    """Search characters of every UTF-8 width, including duplicates."""
    search_chars, width_offsets = spacy.util.get_search_char_byte_arrays(
        "𐌞Éabé—B𐌞", case_sensitive
    )
    if case_sensitive:
        # 3 one-byte, 2 two-byte, 1 three-byte and 1 four-byte character.
        assert search_chars == "BabÉé—𐌞".encode("utf-8")
        assert width_offsets == b"\x00\x03\x07\x0a\x0e"
    else:
        # Lower-casing merges "B" with "b" and "É" with "é".
        assert search_chars == "abé—𐌞".encode("utf-8")
        assert width_offsets == b"\x00\x02\x04\x07\x0b"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("case_sensitive", [True, False])
def test_turkish_i_with_dot(case_sensitive):
    """Capital dotted I: one 2-byte character that lower-cases to two characters."""
    search_chars, width_offsets = spacy.util.get_search_char_byte_arrays(
        "İ", case_sensitive
    )
    if case_sensitive:
        assert search_chars == "İ".encode("utf-8")
        assert width_offsets == b"\x00\x00\x02\x02\x02"
    else:
        # The expected bytes are "i" followed by the two-byte sequence
        # b"\xcc\x87" produced by lower-casing.
        assert search_chars == b"i\xcc\x87"
        assert width_offsets == b"\x00\x01\x03\x03\x03"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("case_sensitive", [True, False])
def test_turkish_i_with_dot_and_normal_i(case_sensitive):
    """Capital dotted I together with a plain capital I."""
    search_chars, width_offsets = spacy.util.get_search_char_byte_arrays(
        "İI", case_sensitive
    )
    if case_sensitive:
        assert search_chars == "Iİ".encode("utf-8")
        assert width_offsets == b"\x00\x01\x03\x03\x03"
    else:
        # Both inputs lower-case to an "i"; the duplicate is removed, leaving
        # the same result as for the dotted I on its own.
        assert search_chars == b"i\xcc\x87"
        assert width_offsets == b"\x00\x01\x03\x03\x03"
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
from turtle import width
|
||||
from typing import List, Mapping, NoReturn, Union, Dict, Any, Set, cast
|
||||
from typing import Optional, Iterable, Callable, Tuple, Type
|
||||
from typing import Iterator, Pattern, Generator, TYPE_CHECKING
|
||||
|
@ -1738,41 +1739,41 @@ def all_equal(iterable):
|
|||
|
||||
|
||||
def get_search_char_byte_arrays(
    search_char_string: str, case_sensitive: bool
) -> Tuple[bytes, bytes]:
    """
    This function supports the rich feature extractor. It orders the characters in
    *search_char_string*, removing any duplicates, encodes them with UTF-8, and
    returns the result together with a byte array containing the offsets where the
    characters of various byte lengths start within the result, i.e.

    <1-byte-start>, <2-byte-start>, <3-byte-start>, <4-byte-start>, <4-byte-end>.

    If the string does not contain any characters of length *n*,
    <n_byte_start> == <n+1_byte_start>.

    search_char_string (str): The characters to search for. It is converted to
        lower case first if *case_sensitive==False*.
    case_sensitive (bool): Whether upper and lower case are kept distinct.
    RETURNS (Tuple[bytes, bytes]): The UTF-8 encoded search characters and the
        five width offsets, one byte per offset.

    Because each offset is stored in a single byte, *bytes()* raises a
    ValueError if the encoded search characters are longer than 255 bytes.
    """
    if not case_sensitive:
        search_char_string = search_char_string.lower()
    # Sorting by code point also groups the characters by UTF-8 width in
    # ascending order, because wider encodings belong to higher code points.
    encoded_chars = [
        char.encode("UTF-8") for char in sorted(set(search_char_string))
    ]
    # bytes_per_width[n - 1] is the total number of bytes taken up by the
    # characters that are n bytes wide.
    bytes_per_width = [0, 0, 0, 0]
    for encoded_char in encoded_chars:
        bytes_per_width[len(encoded_char) - 1] += len(encoded_char)
    # Running totals give the start of each width group plus the overall end.
    width_offsets = [0]
    for byte_count in bytes_per_width:
        width_offsets.append(width_offsets[-1] + byte_count)
    return b"".join(encoded_chars), bytes(width_offsets)
|
||||
|
|
Loading…
Reference in New Issue
Block a user