Mirror of https://github.com/explosion/spaCy.git
Return 64-bit integers
This commit is contained in:
parent fc72ee21c5
commit d575b9f8d4
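
The change replaces Doc.get_affixes, which returned arrays of raw UTF-16 code units, with Doc.get_affix_hashes, which hashes each affix with MurmurHash and returns unsigned 64-bit integers. A minimal sketch of that conversion, mirroring the _get_unsigned_64_bit_hash helper added in the test below (the function name here is illustrative; it only assumes the murmurhash package already used by spaCy):

import numpy
import murmurhash.mrmr

def unsigned_64_bit_hash(text: str) -> int:
    # murmurhash.mrmr.hash returns a signed 32-bit int; casting through a
    # NumPy array to uint64 wraps negative values modulo 2**64, giving the
    # unsigned 64-bit representation the new Doc method returns.
    return numpy.asarray([murmurhash.mrmr.hash(text)]).astype("uint64")[0]

print(unsigned_64_bit_hash("spa"))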
@@ -2,6 +2,7 @@ import weakref
 
 import numpy
 from numpy.testing import assert_array_equal
+import murmurhash.mrmr
 import pytest
 import warnings
 from thinc.api import NumpyOps, get_current_ops

@@ -976,39 +977,70 @@ def test_doc_spans_setdefault(en_tokenizer):
     doc.spans.setdefault("key3", default=SpanGroup(doc, spans=[doc[0:1], doc[1:2]]))
     assert len(doc.spans["key3"]) == 2
 
 
 @pytest.mark.parametrize(
     "case_sensitive", [True, False]
 )
-def test_get_affixes_good_case(en_tokenizer, case_sensitive):
+def test_get_affix_hashes_good_case(en_tokenizer, case_sensitive):
+    def _get_unsigned_64_bit_hash(input: str) -> int:
+        if not case_sensitive:
+            input = input.lower()
+        return numpy.asarray([murmurhash.mrmr.hash(input)]).astype("uint64")[0]
+
     doc = en_tokenizer("spaCy✨ and Prodigy")
-    prefixes = doc.get_affixes(False, case_sensitive, 1, 5, "", 2, 3)
-    suffixes = doc.get_affixes(True, case_sensitive, 2, 6, "xx✨rP", 2, 3)
-    assert prefixes[3][3, 3] == suffixes[3][3, 3]
-    assert prefixes[3][3, 2] == suffixes[3][3, 4]
-    assert suffixes[4][1].tolist() == [10024, 0]
+    prefixes = doc.get_affix_hashes(False, case_sensitive, 1, 5, "", 2, 3)
+    suffixes = doc.get_affix_hashes(True, case_sensitive, 2, 6, "xx✨rP", 1, 3)
+    assert prefixes[0][0] == _get_unsigned_64_bit_hash("s")
+    assert prefixes[0][1] == _get_unsigned_64_bit_hash("sp")
+    assert prefixes[0][2] == _get_unsigned_64_bit_hash("spa")
+    assert prefixes[0][3] == _get_unsigned_64_bit_hash("spaC")
+    assert prefixes[0][4] == _get_unsigned_64_bit_hash(" ")
+    assert prefixes[1][0] == _get_unsigned_64_bit_hash("✨")
+    assert prefixes[1][1] == _get_unsigned_64_bit_hash("✨")
+    assert prefixes[1][2] == _get_unsigned_64_bit_hash("✨")
+    assert prefixes[1][3] == _get_unsigned_64_bit_hash("✨")
+    assert prefixes[1][4] == _get_unsigned_64_bit_hash(" ")
+    assert prefixes[2][0] == _get_unsigned_64_bit_hash("a")
+    assert prefixes[2][1] == _get_unsigned_64_bit_hash("an")
+    assert prefixes[2][2] == _get_unsigned_64_bit_hash("and")
+    assert prefixes[2][3] == _get_unsigned_64_bit_hash("and")
+    assert prefixes[2][4] == _get_unsigned_64_bit_hash(" ")
+    assert prefixes[3][0] == _get_unsigned_64_bit_hash("P")
+    assert prefixes[3][1] == _get_unsigned_64_bit_hash("Pr")
+    assert prefixes[3][2] == _get_unsigned_64_bit_hash("Pro")
+    assert prefixes[3][3] == _get_unsigned_64_bit_hash("Prod")
+    assert prefixes[3][4] == _get_unsigned_64_bit_hash(" ")
+
+    assert suffixes[0][0] == _get_unsigned_64_bit_hash("yC")
+    assert suffixes[0][1] == _get_unsigned_64_bit_hash("yCa")
+    assert suffixes[0][2] == _get_unsigned_64_bit_hash("yCap")
+    assert suffixes[0][3] == _get_unsigned_64_bit_hash("yCaps")
+    assert suffixes[0][4] == _get_unsigned_64_bit_hash(" ")
+    assert suffixes[0][5] == _get_unsigned_64_bit_hash(" ")
+    assert suffixes[1][0] == _get_unsigned_64_bit_hash("✨")
+    assert suffixes[1][1] == _get_unsigned_64_bit_hash("✨")
+    assert suffixes[1][2] == _get_unsigned_64_bit_hash("✨")
+    assert suffixes[1][3] == _get_unsigned_64_bit_hash("✨")
+    assert suffixes[1][4] == _get_unsigned_64_bit_hash("✨")
+    assert suffixes[1][5] == _get_unsigned_64_bit_hash("✨ ")
+    assert suffixes[2][0] == _get_unsigned_64_bit_hash("dn")
+    assert suffixes[2][1] == _get_unsigned_64_bit_hash("dna")
+    assert suffixes[2][2] == _get_unsigned_64_bit_hash("dna")
+    assert suffixes[2][3] == _get_unsigned_64_bit_hash("dna")
+    assert suffixes[2][4] == _get_unsigned_64_bit_hash(" ")
+    assert suffixes[2][5] == _get_unsigned_64_bit_hash(" ")
+    assert suffixes[3][0] == _get_unsigned_64_bit_hash("yg")
+    assert suffixes[3][1] == _get_unsigned_64_bit_hash("ygi")
+    assert suffixes[3][2] == _get_unsigned_64_bit_hash("ygid")
+    assert suffixes[3][3] == _get_unsigned_64_bit_hash("ygido")
+    assert suffixes[3][4] == _get_unsigned_64_bit_hash("r")
 
     if case_sensitive:
-        assert suffixes[4][3].tolist() == [114, 80]
+        assert suffixes[3][5] == _get_unsigned_64_bit_hash("rP")
     else:
-        assert suffixes[4][3].tolist() == [114, 112]
-    suffixes = doc.get_affixes(True, case_sensitive, 2, 6, "xx✨rp", 2, 3)
-    if case_sensitive:
-        assert suffixes[4][3].tolist() == [114, 0]
-    else:
-        assert suffixes[4][3].tolist() == [114, 112]
-
-
-def test_get_affixes_4_byte_normal_char(en_tokenizer):
-    doc = en_tokenizer("and𐌞")
-    suffixes = doc.get_affixes(True, False, 2, 6, "a", 1, 2)
-    for i in range(0, 4):
-        assert suffixes[i][0, 1] == 55296
-    assert suffixes[3][0, 4] == 97
-    assert suffixes[4][0, 0] == 97
-
-
-def test_get_affixes_4_byte_special_char(en_tokenizer):
-    doc = en_tokenizer("and𐌞")
-    with pytest.raises(ValueError):
-        doc.get_affixes(True, False, 2, 6, "𐌞", 2, 3)
+        assert suffixes[3][5] == _get_unsigned_64_bit_hash("r ")
+
+    # check values are the same cross-platform
+    assert prefixes[0][3] == 18446744072456113490 if case_sensitive else 18446744071614199016
+    assert suffixes[1][0] == 910783208
+    assert suffixes[2][5] == 1696457176

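Taken together, the new assertions pin down the layout of the return value: one row per token, with the first len_end - len_start columns holding hashes of prefixes (or suffixes) of increasing length and the remaining sc_len_end - sc_len_start columns holding hashes over the token's special-character subsequence. A usage sketch under that reading (this assumes the branch containing get_affix_hashes is installed; the second argument is passed the same way the test passes its case_sensitive flag):

import numpy
import spacy

nlp = spacy.blank("en")
doc = nlp("spaCy and Prodigy")

# prefixes (False), case flag True, affix lengths 1..4, no special
# characters, one special-character hash of length 2 per token
hashes = doc.get_affix_hashes(False, True, 1, 5, "", 2, 3)

assert hashes.shape == (len(doc), (5 - 1) + (3 - 2))  # 3 tokens x 5 hashes
assert hashes.dtype == numpy.uint64
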
@@ -9,7 +9,6 @@ from ..attrs cimport attr_id_t
 
 cdef attr_t get_token_attr(const TokenC* token, attr_id_t feat_name) nogil
 cdef attr_t get_token_attr_for_matcher(const TokenC* token, attr_id_t feat_name) nogil
-cdef const unsigned char[:] get_utf16_memoryview(str unicode_string, bint check_2_bytes)
 
 
 ctypedef const LexemeC* const_Lexeme_ptr

@@ -16,6 +16,7 @@ import srsly
 from thinc.api import get_array_module, get_current_ops
 from thinc.util import copy_array
 import warnings
+import murmurhash.mrmr
 
 from .span cimport Span
 from .token cimport MISSING_DEP

@@ -94,18 +95,6 @@ cdef attr_t get_token_attr_for_matcher(const TokenC* token, attr_id_t feat_name)
     return get_token_attr(token, feat_name)
 
 
-cdef const unsigned char[:] get_utf16_memoryview(str unicode_string, bint check_2_bytes):
-    """
-    Returns a memory view of the UTF-16 representation of a string with the default endianness of the platform.
-    Throws a ValueError if *check_2_bytes == True* and one or more characters in the UTF-16 representation
-    occupy four bytes rather than two.
-    """
-    cdef const unsigned char[:] view = memoryview(unicode_string.encode("UTF-16"))[2:]  # first two bytes are endianness
-    if check_2_bytes and len(unicode_string) * 2 != len(view):
-        raise ValueError(Errors.E1044)
-    return view
-
-
 class SetEntsDefault(str, Enum):
     blocked = "blocked"
     missing = "missing"

@@ -1746,63 +1735,31 @@ cdef class Doc:
                 j += 1
         return output
 
-    def get_affixes(self, bint suffs_not_prefs, bint case_sensitive, unsigned int len_start, unsigned int len_end,
+    def get_affix_hashes(self, bint suffs_not_prefs, bint lower_not_orth, unsigned int len_start, unsigned int len_end,
             str special_chars, unsigned int sc_len_start, unsigned int sc_len_end):
         """
         TODO
         """
-        if case_sensitive:
-            token_attrs = [t.orth_ for t in self]
-        else:
-            token_attrs = [t.lower_ for t in self]
-            special_chars = special_chars.lower()
-        cdef unsigned int sc_len = len(special_chars)
-        cdef const unsigned char[:] sc_bytes = get_utf16_memoryview(special_chars, True)
-        cdef np.ndarray[np.uint16_t, ndim=1] scs = numpy.ndarray((sc_len,), buffer=sc_bytes, dtype="uint16")
-        cdef np.ndarray[np.uint16_t, ndim=2] output
-        cdef unsigned int num_tokens = len(self), num_normal_affixes = len_end - len_start, working_len
-
-        outputs = []
-        for working_len in range(num_normal_affixes):
-            output = numpy.zeros((num_tokens, len_start + working_len), dtype="uint16")
-            outputs.append(output)
-        for working_len in range(sc_len_end - sc_len_start):
-            output = numpy.zeros((num_tokens, sc_len_start + working_len), dtype="uint16")
-            outputs.append(output)
-
-        cdef const unsigned char[:] token_bytes
-        cdef np.uint16_t working_char
-        cdef unsigned int token_bytes_len, token_idx, char_idx, sc_char_idx, sc_test_idx, working_sc_len
-        cdef unsigned int char_byte_idx
-
-        for token_idx in range(num_tokens):
-            token_bytes = get_utf16_memoryview(token_attrs[token_idx], False)
-            char_idx = 0
-            sc_char_idx = 0
-            token_bytes_len = len(token_bytes)
-
-            while (char_idx < len_end - 1 or sc_char_idx < sc_len_end - 1) and char_idx * 2 < token_bytes_len:
-                if suffs_not_prefs:
-                    char_byte_idx = token_bytes_len - 2 * (char_idx + 1)
-                else:
-                    char_byte_idx = char_idx * 2
-                working_char = (<np.uint16_t*> &token_bytes[char_byte_idx])[0]
-                for working_len in range(len_end-1, len_start-1, -1):
-                    if char_idx >= working_len:
-                        break
-                    outputs[working_len - len_start][token_idx, char_idx] = working_char
-                sc_test_idx = 0
-                while sc_len > sc_test_idx:
-                    if working_char == scs[sc_test_idx]:
-                        for working_sc_len in range(sc_len_end-1, sc_len_start-1, -1):
-                            if sc_char_idx >= working_sc_len:
-                                break
-                            outputs[num_normal_affixes + working_sc_len - sc_len_start][token_idx, sc_char_idx] = working_char
-                        sc_char_idx += 1
-                        break
-                    sc_test_idx += 1
-                char_idx += 1
-        return outputs
+        cdef unsigned int token_index, norm_hash_index, spec_hash_index
+        cdef str token_string, specials_string, working_substring
+        cdef unsigned int num_norm_hashes = len_end - len_start, num_spec_hashes = sc_len_end - sc_len_start, num_tokens = len(self)
+        cdef np.ndarray[np.int64_t, ndim=2] output = numpy.empty((num_tokens, num_norm_hashes + num_spec_hashes), dtype="int64")
+
+        for token_index in range(num_tokens):
+            token_string = self[token_index].orth_ if lower_not_orth else self[token_index].lower_
+            if suffs_not_prefs:
+                token_string = token_string[::-1]
+
+            for norm_hash_index in range(num_norm_hashes):
+                working_substring = token_string[: len_start + norm_hash_index]
+                output[token_index, norm_hash_index] = murmurhash.mrmr.hash(working_substring)
+
+            specials_string = "".join([c for c in token_string if c in special_chars]) + " " * sc_len_end
+            for spec_hash_index in range(num_spec_hashes):
+                working_substring = specials_string[: sc_len_start + spec_hash_index]
+                output[token_index, num_norm_hashes + spec_hash_index] = murmurhash.mrmr.hash(working_substring)
+
+        return output.astype("uint64", casting="unsafe", copy=False)
 
     @staticmethod
     def _get_array_attrs():

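The Cython above can be read as the following plain-Python sketch; it is a behavioural outline rather than the implementation itself, and it applies the second flag as case-sensitivity, which is how the test exercises it:

import numpy
import murmurhash.mrmr

def affix_hashes(token_texts, suffs_not_prefs, case_sensitive,
                 len_start, len_end, special_chars, sc_len_start, sc_len_end):
    # token_texts stands in for iterating over the tokens of a Doc
    num_norm = len_end - len_start
    num_spec = sc_len_end - sc_len_start
    output = numpy.empty((len(token_texts), num_norm + num_spec), dtype="int64")
    for i, text in enumerate(token_texts):
        if not case_sensitive:
            text = text.lower()
        if suffs_not_prefs:
            text = text[::-1]  # suffixes are hashed as reversed prefixes
        for j in range(num_norm):
            output[i, j] = murmurhash.mrmr.hash(text[: len_start + j])
        # keep only the special characters, padded with spaces so every
        # requested slice length is defined
        specials = "".join(c for c in text if c in special_chars) + " " * sc_len_end
        for j in range(num_spec):
            output[i, num_norm + j] = murmurhash.mrmr.hash(specials[: sc_len_start + j])
    return output.astype("uint64", casting="unsafe", copy=False)

print(affix_hashes(["spaCy", "and", "Prodigy"], False, True, 1, 5, "", 2, 3))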