mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-27 09:44:36 +03:00
Use max(uint64) for OOV lexeme rank (#5303)
* Use max(uint64) for OOV lexeme rank * Add test for default OOV rank * Revert back to thinc==7.4.0 Requiring the updated version of thinc was unnecessary. * Define OOV_RANK in one place Define OOV_RANK in one place in `util`. * Fix formatting [ci skip] * Switch to external definitions of max(uint64) Switch to external defintions of max(uint64) and confirm that they are equal.
This commit is contained in:
parent
3d2c308906
commit
98c59027ed
|
@ -289,7 +289,7 @@ def link_vectors_to_models(vocab):
|
||||||
if word.orth in vectors.key2row:
|
if word.orth in vectors.key2row:
|
||||||
word.rank = vectors.key2row[word.orth]
|
word.rank = vectors.key2row[word.orth]
|
||||||
else:
|
else:
|
||||||
word.rank = 0
|
word.rank = util.OOV_RANK
|
||||||
data = ops.asarray(vectors.data)
|
data = ops.asarray(vectors.data)
|
||||||
# Set an entry here, so that vectors are accessed by StaticVectors
|
# Set an entry here, so that vectors are accessed by StaticVectors
|
||||||
# (unideal, I know)
|
# (unideal, I know)
|
||||||
|
|
|
@ -16,7 +16,7 @@ from wasabi import msg
|
||||||
|
|
||||||
from ..vectors import Vectors
|
from ..vectors import Vectors
|
||||||
from ..errors import Errors, Warnings, user_warning
|
from ..errors import Errors, Warnings, user_warning
|
||||||
from ..util import ensure_path, get_lang_class
|
from ..util import ensure_path, get_lang_class, OOV_RANK
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import ftfy
|
import ftfy
|
||||||
|
@ -148,7 +148,7 @@ def create_model(lang, lex_attrs, name=None):
|
||||||
lang_class = get_lang_class(lang)
|
lang_class = get_lang_class(lang)
|
||||||
nlp = lang_class()
|
nlp = lang_class()
|
||||||
for lexeme in nlp.vocab:
|
for lexeme in nlp.vocab:
|
||||||
lexeme.rank = 0
|
lexeme.rank = OOV_RANK
|
||||||
lex_added = 0
|
lex_added = 0
|
||||||
for attrs in lex_attrs:
|
for attrs in lex_attrs:
|
||||||
if "settings" in attrs:
|
if "settings" in attrs:
|
||||||
|
|
|
@ -10,6 +10,7 @@ from numpy cimport ndarray
|
||||||
|
|
||||||
|
|
||||||
cdef LexemeC EMPTY_LEXEME
|
cdef LexemeC EMPTY_LEXEME
|
||||||
|
cdef attr_t OOV_RANK
|
||||||
|
|
||||||
cdef class Lexeme:
|
cdef class Lexeme:
|
||||||
cdef LexemeC* c
|
cdef LexemeC* c
|
||||||
|
|
|
@ -11,6 +11,7 @@ np.import_array()
|
||||||
import numpy
|
import numpy
|
||||||
from thinc.neural.util import get_array_module
|
from thinc.neural.util import get_array_module
|
||||||
|
|
||||||
|
from libc.stdint cimport UINT64_MAX
|
||||||
from .typedefs cimport attr_t, flags_t
|
from .typedefs cimport attr_t, flags_t
|
||||||
from .attrs cimport IS_ALPHA, IS_ASCII, IS_DIGIT, IS_LOWER, IS_PUNCT, IS_SPACE
|
from .attrs cimport IS_ALPHA, IS_ASCII, IS_DIGIT, IS_LOWER, IS_PUNCT, IS_SPACE
|
||||||
from .attrs cimport IS_TITLE, IS_UPPER, LIKE_URL, LIKE_NUM, LIKE_EMAIL, IS_STOP
|
from .attrs cimport IS_TITLE, IS_UPPER, LIKE_URL, LIKE_NUM, LIKE_EMAIL, IS_STOP
|
||||||
|
@ -21,7 +22,9 @@ from .attrs import intify_attrs
|
||||||
from .errors import Errors, Warnings, user_warning
|
from .errors import Errors, Warnings, user_warning
|
||||||
|
|
||||||
|
|
||||||
|
OOV_RANK = UINT64_MAX
|
||||||
memset(&EMPTY_LEXEME, 0, sizeof(LexemeC))
|
memset(&EMPTY_LEXEME, 0, sizeof(LexemeC))
|
||||||
|
EMPTY_LEXEME.id = OOV_RANK
|
||||||
|
|
||||||
|
|
||||||
cdef class Lexeme:
|
cdef class Lexeme:
|
||||||
|
|
|
@ -2,7 +2,9 @@
|
||||||
from __future__ import unicode_literals
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
import numpy
|
||||||
from spacy.attrs import IS_ALPHA, IS_DIGIT
|
from spacy.attrs import IS_ALPHA, IS_DIGIT
|
||||||
|
from spacy.util import OOV_RANK
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("text1,prob1,text2,prob2", [("NOUN", -1, "opera", -2)])
|
@pytest.mark.parametrize("text1,prob1,text2,prob2", [("NOUN", -1, "opera", -2)])
|
||||||
|
@ -69,3 +71,10 @@ def test_lexeme_bytes_roundtrip(en_vocab):
|
||||||
assert one.orth == alpha.orth
|
assert one.orth == alpha.orth
|
||||||
assert one.lower == alpha.lower
|
assert one.lower == alpha.lower
|
||||||
assert one.lower_ == alpha.lower_
|
assert one.lower_ == alpha.lower_
|
||||||
|
|
||||||
|
|
||||||
|
def test_vocab_lexeme_oov_rank(en_vocab):
|
||||||
|
"""Test that default rank is OOV_RANK."""
|
||||||
|
lex = en_vocab["word"]
|
||||||
|
assert OOV_RANK == numpy.iinfo(numpy.uint64).max
|
||||||
|
assert lex.rank == OOV_RANK
|
||||||
|
|
|
@ -12,6 +12,7 @@ from thinc.neural.ops import NumpyOps
|
||||||
import functools
|
import functools
|
||||||
import itertools
|
import itertools
|
||||||
import numpy.random
|
import numpy.random
|
||||||
|
import numpy
|
||||||
import srsly
|
import srsly
|
||||||
import catalogue
|
import catalogue
|
||||||
import sys
|
import sys
|
||||||
|
@ -34,6 +35,7 @@ from .errors import Errors, Warnings, deprecation_warning
|
||||||
|
|
||||||
_data_path = Path(__file__).parent / "data"
|
_data_path = Path(__file__).parent / "data"
|
||||||
_PRINT_ENV = False
|
_PRINT_ENV = False
|
||||||
|
OOV_RANK = numpy.iinfo(numpy.uint64).max
|
||||||
|
|
||||||
|
|
||||||
class registry(object):
|
class registry(object):
|
||||||
|
|
|
@ -7,7 +7,7 @@ import srsly
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from thinc.neural.util import get_array_module
|
from thinc.neural.util import get_array_module
|
||||||
|
|
||||||
from .lexeme cimport EMPTY_LEXEME
|
from .lexeme cimport EMPTY_LEXEME, OOV_RANK
|
||||||
from .lexeme cimport Lexeme
|
from .lexeme cimport Lexeme
|
||||||
from .typedefs cimport attr_t
|
from .typedefs cimport attr_t
|
||||||
from .tokens.token cimport Token
|
from .tokens.token cimport Token
|
||||||
|
@ -165,9 +165,9 @@ cdef class Vocab:
|
||||||
lex.orth = self.strings.add(string)
|
lex.orth = self.strings.add(string)
|
||||||
lex.length = len(string)
|
lex.length = len(string)
|
||||||
if self.vectors is not None:
|
if self.vectors is not None:
|
||||||
lex.id = self.vectors.key2row.get(lex.orth, 0)
|
lex.id = self.vectors.key2row.get(lex.orth, OOV_RANK)
|
||||||
else:
|
else:
|
||||||
lex.id = 0
|
lex.id = OOV_RANK
|
||||||
if self.lex_attr_getters is not None:
|
if self.lex_attr_getters is not None:
|
||||||
for attr, func in self.lex_attr_getters.items():
|
for attr, func in self.lex_attr_getters.items():
|
||||||
value = func(string)
|
value = func(string)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user