diff --git a/spacy/lang/lex_attrs.py b/spacy/lang/lex_attrs.py
index 6ed981a06..0465c448a 100644
--- a/spacy/lang/lex_attrs.py
+++ b/spacy/lang/lex_attrs.py
@@ -23,21 +23,21 @@ _tlds = set(
 )


-def is_punct(text: str) -> bool:
+def is_punct(vocab, text: str) -> bool:
     for char in text:
         if not unicodedata.category(char).startswith("P"):
             return False
     return True


-def is_ascii(text: str) -> bool:
+def is_ascii(vocab, text: str) -> bool:
     for char in text:
         if ord(char) >= 128:
             return False
     return True


-def like_num(text: str) -> bool:
+def like_num(vocab, text: str) -> bool:
     if text.startswith(("+", "-", "±", "~")):
         text = text[1:]
     # can be overwritten by lang with list of number words
@@ -51,31 +51,31 @@ def like_num(text: str) -> bool:
     return False


-def is_bracket(text: str) -> bool:
+def is_bracket(vocab, text: str) -> bool:
     brackets = ("(", ")", "[", "]", "{", "}", "<", ">")
     return text in brackets


-def is_quote(text: str) -> bool:
+def is_quote(vocab, text: str) -> bool:
     # fmt: off
     quotes = ('"', "'", "`", "«", "»", "‘", "’", "‚", "‛", "“", "”", "„", "‟", "‹", "›", "❮", "❯", "''", "``")
     # fmt: on
     return text in quotes


-def is_left_punct(text: str) -> bool:
+def is_left_punct(vocab, text: str) -> bool:
     # fmt: off
     left_punct = ("(", "[", "{", "<", '"', "'", "«", "‘", "‚", "‛", "“", "„", "‟", "‹", "❮", "``")
     # fmt: on
     return text in left_punct


-def is_right_punct(text: str) -> bool:
+def is_right_punct(vocab, text: str) -> bool:
     right_punct = (")", "]", "}", ">", '"', "'", "»", "’", "”", "›", "❯", "''")
     return text in right_punct


-def is_currency(text: str) -> bool:
+def is_currency(vocab, text: str) -> bool:
     # can be overwritten by lang with list of currency words, e.g. dollar, euro
     for char in text:
         if unicodedata.category(char) != "Sc":
@@ -83,11 +83,11 @@ def is_currency(text: str) -> bool:
     return True


-def like_email(text: str) -> bool:
+def like_email(vocab, text: str) -> bool:
     return bool(_like_email(text))


-def like_url(text: str) -> bool:
+def like_url(vocab, text: str) -> bool:
     # We're looking for things that function in text like URLs. So, valid URL
     # or not, anything they say http:// is going to be good.
if text.startswith("http://") or text.startswith("https://"): @@ -115,7 +115,7 @@ def like_url(text: str) -> bool: return False -def word_shape(text: str) -> str: +def word_shape(vocab, text: str) -> str: if len(text) >= 100: return "LONG" shape = [] @@ -142,55 +142,54 @@ def word_shape(text: str) -> str: return "".join(shape) -def lower(string: str) -> str: +def lower(vocab, string: str) -> str: return string.lower() -def prefix(string: str) -> str: +def prefix(vocab, string: str) -> str: return string[0] -def suffix(string: str) -> str: +def suffix(vocab, string: str) -> str: return string[-3:] -def is_alpha(string: str) -> bool: +def is_alpha(vocab, string: str) -> bool: return string.isalpha() -def is_digit(string: str) -> bool: +def is_digit(vocab, string: str) -> bool: return string.isdigit() -def is_lower(string: str) -> bool: +def is_lower(vocab, string: str) -> bool: return string.islower() -def is_space(string: str) -> bool: +def is_space(vocab, string: str) -> bool: return string.isspace() -def is_title(string: str) -> bool: +def is_title(vocab, string: str) -> bool: return string.istitle() -def is_upper(string: str) -> bool: +def is_upper(vocab, string: str) -> bool: return string.isupper() -def is_stop(string: str, stops: Set[str] = set()) -> bool: +def is_stop(vocab, string: str) -> bool: + stops = vocab.lex_attr_data.get("stops", []) return string.lower() in stops -def get_lang(text: str, lang: str = "") -> str: - # This function is partially applied so lang code can be passed in - # automatically while still allowing pickling - return lang +def norm(vocab, string: str) -> str: + return vocab.lookups.get_table("lexeme_norm", {}).get(string, string.lower()) LEX_ATTRS = { attrs.LOWER: lower, - attrs.NORM: lower, + attrs.NORM: norm, attrs.PREFIX: prefix, attrs.SUFFIX: suffix, attrs.IS_ALPHA: is_alpha, diff --git a/spacy/language.py b/spacy/language.py index d2b89029d..1e4220f86 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -74,6 +74,7 @@ class BaseDefaults: url_match: Optional[Callable] = URL_MATCH syntax_iterators: Dict[str, Callable] = {} lex_attr_getters: Dict[int, Callable[[str], Any]] = {} + lex_attr_data: Dict[str, Any] = {} stop_words: Set[str] = set() writing_system = {"direction": "ltr", "has_case": True, "has_letters": True} diff --git a/spacy/lexeme.pxd b/spacy/lexeme.pxd index 2d14edcd6..9483a4457 100644 --- a/spacy/lexeme.pxd +++ b/spacy/lexeme.pxd @@ -2,7 +2,7 @@ from numpy cimport ndarray from .typedefs cimport attr_t, hash_t, flags_t, len_t, tag_t from .attrs cimport attr_id_t -from .attrs cimport ID, ORTH, LOWER, NORM, SHAPE, PREFIX, SUFFIX, LENGTH, LANG +from .attrs cimport ID, ORTH, LOWER, NORM, SHAPE, PREFIX, SUFFIX, LENGTH from .structs cimport LexemeC from .vocab cimport Vocab @@ -40,8 +40,6 @@ cdef class Lexeme: lex.prefix = value elif name == SUFFIX: lex.suffix = value - elif name == LANG: - lex.lang = value @staticmethod cdef inline attr_t get_struct_attr(const LexemeC* lex, attr_id_t feat_name) nogil: @@ -66,8 +64,6 @@ cdef class Lexeme: return lex.suffix elif feat_name == LENGTH: return lex.length - elif feat_name == LANG: - return lex.lang else: return 0 diff --git a/spacy/lexeme.pyx b/spacy/lexeme.pyx index e57098f17..d767d152d 100644 --- a/spacy/lexeme.pyx +++ b/spacy/lexeme.pyx @@ -247,13 +247,9 @@ cdef class Lexeme: cluster_table = self.vocab.lookups.get_table("lexeme_cluster", {}) cluster_table[self.c.orth] = x - property lang: - """RETURNS (uint64): Language of the parent vocabulary.""" - def __get__(self): - return self.c.lang - - 
def __set__(self, attr_t x): - self.c.lang = x + @property + def lang(self): + return self.vocab.strings[self.vocab.lang] property prob: """RETURNS (float): Smoothed log probability estimate of the lexeme's @@ -316,13 +312,9 @@ cdef class Lexeme: def __set__(self, str x): self.c.suffix = self.vocab.strings.add(x) - property lang_: - """RETURNS (str): Language of the parent vocabulary.""" - def __get__(self): - return self.vocab.strings[self.c.lang] - - def __set__(self, str x): - self.c.lang = self.vocab.strings.add(x) + @property + def lang_(self): + return self.vocab.lang property flags: """RETURNS (uint64): Container of the lexeme's binary flags.""" diff --git a/spacy/structs.pxd b/spacy/structs.pxd index b9b6f6ba8..4d64e8cd4 100644 --- a/spacy/structs.pxd +++ b/spacy/structs.pxd @@ -11,8 +11,6 @@ from .parts_of_speech cimport univ_pos_t cdef struct LexemeC: flags_t flags - attr_t lang - attr_t id attr_t length diff --git a/spacy/tests/lang/test_attrs.py b/spacy/tests/lang/test_attrs.py index 1e1bae08c..016f94ca2 100644 --- a/spacy/tests/lang/test_attrs.py +++ b/spacy/tests/lang/test_attrs.py @@ -9,8 +9,8 @@ from spacy.lang.lex_attrs import like_url, word_shape @pytest.mark.parametrize("word", ["the"]) @pytest.mark.issue(1889) -def test_issue1889(word): - assert is_stop(word, STOP_WORDS) == is_stop(word.upper(), STOP_WORDS) +def test_issue1889(en_vocab, word): + assert is_stop(en_vocab, word) == is_stop(en_vocab, word.upper()) @pytest.mark.parametrize("text", ["dog"]) @@ -59,13 +59,13 @@ def test_attrs_ent_iob_intify(): @pytest.mark.parametrize("text,match", [(",", True), (" ", False), ("a", False)]) -def test_lex_attrs_is_punct(text, match): - assert is_punct(text) == match +def test_lex_attrs_is_punct(en_vocab, text, match): + assert is_punct(en_vocab, text) == match @pytest.mark.parametrize("text,match", [(",", True), ("£", False), ("♥", False)]) -def test_lex_attrs_is_ascii(text, match): - assert is_ascii(text) == match +def test_lex_attrs_is_ascii(en_vocab, text, match): + assert is_ascii(en_vocab, text) == match @pytest.mark.parametrize( @@ -82,8 +82,8 @@ def test_lex_attrs_is_ascii(text, match): ("dog", False), ], ) -def test_lex_attrs_is_currency(text, match): - assert is_currency(text) == match +def test_lex_attrs_is_currency(en_vocab, text, match): + assert is_currency(en_vocab, text) == match @pytest.mark.parametrize( @@ -102,8 +102,8 @@ def test_lex_attrs_is_currency(text, match): ("hello.There", False), ], ) -def test_lex_attrs_like_url(text, match): - assert like_url(text) == match +def test_lex_attrs_like_url(en_vocab, text, match): + assert like_url(en_vocab, text) == match @pytest.mark.parametrize( @@ -118,5 +118,5 @@ def test_lex_attrs_like_url(text, match): ("``,-", "``,-"), ], ) -def test_lex_attrs_word_shape(text, shape): - assert word_shape(text) == shape +def test_lex_attrs_word_shape(en_vocab, text, shape): + assert word_shape(en_vocab, text) == shape diff --git a/spacy/tests/serialize/test_serialize_vocab_strings.py b/spacy/tests/serialize/test_serialize_vocab_strings.py index fd80c3d8e..3cfcb5974 100644 --- a/spacy/tests/serialize/test_serialize_vocab_strings.py +++ b/spacy/tests/serialize/test_serialize_vocab_strings.py @@ -6,10 +6,12 @@ from thinc.api import get_current_ops import spacy from spacy.lang.en import English from spacy.strings import StringStore +from spacy.symbols import NORM from spacy.tokens import Doc from spacy.util import ensure_path, load_model from spacy.vectors import Vectors from spacy.vocab import Vocab +from spacy.lang.lex_attrs import 
norm from ..util import make_tempdir @@ -73,6 +75,7 @@ def test_serialize_vocab(en_vocab, text): new_vocab = Vocab().from_bytes(vocab_bytes) assert new_vocab.strings[text_hash] == text assert new_vocab.to_bytes(exclude=["lookups"]) == vocab_bytes + assert new_vocab.lex_attr_data == en_vocab.lex_attr_data @pytest.mark.parametrize("strings1,strings2", test_strings) @@ -111,12 +114,13 @@ def test_serialize_vocab_roundtrip_disk(strings1, strings2): assert [s for s in vocab1_d.strings] == [s for s in vocab2_d.strings] else: assert [s for s in vocab1_d.strings] != [s for s in vocab2_d.strings] + assert vocab1_d.lex_attr_data == vocab2_d.lex_attr_data @pytest.mark.parametrize("strings,lex_attr", test_strings_attrs) def test_serialize_vocab_lex_attrs_bytes(strings, lex_attr): - vocab1 = Vocab(strings=strings) - vocab2 = Vocab() + vocab1 = Vocab(strings=strings, lex_attr_getters={NORM: norm}) + vocab2 = Vocab(lex_attr_getters={NORM: norm}) vocab1[strings[0]].norm_ = lex_attr assert vocab1[strings[0]].norm_ == lex_attr assert vocab2[strings[0]].norm_ != lex_attr @@ -134,8 +138,8 @@ def test_deserialize_vocab_seen_entries(strings, lex_attr): @pytest.mark.parametrize("strings,lex_attr", test_strings_attrs) def test_serialize_vocab_lex_attrs_disk(strings, lex_attr): - vocab1 = Vocab(strings=strings) - vocab2 = Vocab() + vocab1 = Vocab(strings=strings, lex_attr_getters={NORM: norm}) + vocab2 = Vocab(lex_attr_getters={NORM: norm}) vocab1[strings[0]].norm_ = lex_attr assert vocab1[strings[0]].norm_ == lex_attr assert vocab2[strings[0]].norm_ != lex_attr diff --git a/spacy/tokens/token.pyx b/spacy/tokens/token.pyx index 64c707acd..03f36addc 100644 --- a/spacy/tokens/token.pyx +++ b/spacy/tokens/token.pyx @@ -288,7 +288,7 @@ cdef class Token: """RETURNS (uint64): ID of the language of the parent document's vocabulary. """ - return self.c.lex.lang + return self.doc.lang @property def idx(self): @@ -844,7 +844,7 @@ cdef class Token: """RETURNS (str): Language of the parent document's vocabulary, e.g. 'en'. """ - return self.vocab.strings[self.c.lex.lang] + return self.doc.lang_ property lemma_: """RETURNS (str): The token lemma, i.e. the base form of the word, diff --git a/spacy/util.py b/spacy/util.py index e2ca0e6a4..ef78dc142 100644 --- a/spacy/util.py +++ b/spacy/util.py @@ -1160,28 +1160,6 @@ def compile_infix_regex(entries: Iterable[Union[str, Pattern]]) -> Pattern: return re.compile(expression) -def add_lookups(default_func: Callable[[str], Any], *lookups) -> Callable[[str], Any]: - """Extend an attribute function with special cases. If a word is in the - lookups, the value is returned. Otherwise the previous function is used. - - default_func (callable): The default function to execute. - *lookups (dict): Lookup dictionary mapping string to attribute value. - RETURNS (callable): Lexical attribute getter. - """ - # This is implemented as functools.partial instead of a closure, to allow - # pickle to work. 
-    return functools.partial(_get_attr_unless_lookup, default_func, lookups)
-
-
-def _get_attr_unless_lookup(
-    default_func: Callable[[str], Any], lookups: Dict[str, Any], string: str
-) -> Any:
-    for lookup in lookups:
-        if string in lookup:
-            return lookup[string]  # type: ignore[index]
-    return default_func(string)
-
-
 def update_exc(
     base_exceptions: Dict[str, List[dict]], *addition_dicts
 ) -> Dict[str, List[dict]]:
diff --git a/spacy/vocab.pxd b/spacy/vocab.pxd
index 2db709b71..a2028b9e7 100644
--- a/spacy/vocab.pxd
+++ b/spacy/vocab.pxd
@@ -27,12 +27,14 @@ cdef class Vocab:
     cdef Pool mem
     cdef readonly StringStore strings
     cdef public Morphology morphology
+    cdef public str lang
     cdef public object _vectors
     cdef public object _lookups
+    cdef public object _lex_attr_getters
+    cdef public object _lex_attr_data
     cdef public object writing_system
     cdef public object get_noun_chunks
     cdef readonly int length
-    cdef public object lex_attr_getters
     cdef public object cfg

     cdef const LexemeC* get(self, str string) except NULL
diff --git a/spacy/vocab.pyx b/spacy/vocab.pyx
index a87f50ad4..5e8010e94 100644
--- a/spacy/vocab.pyx
+++ b/spacy/vocab.pyx
@@ -5,6 +5,7 @@ import numpy
 import srsly
 from thinc.api import get_array_module, get_current_ops
 import functools
+import inspect

 from .lexeme cimport EMPTY_LEXEME, OOV_RANK
 from .lexeme cimport Lexeme
@@ -14,29 +15,26 @@ from .attrs cimport LANG, ORTH

 from .compat import copy_reg
 from .errors import Errors
-from .attrs import intify_attrs, NORM, IS_STOP
+from .attrs import intify_attrs, NORM
 from .vectors import Vectors, Mode as VectorsMode
 from .util import registry
 from .lookups import Lookups
 from . import util
 from .lang.norm_exceptions import BASE_NORMS
-from .lang.lex_attrs import LEX_ATTRS, is_stop, get_lang
+from .lang.lex_attrs import LEX_ATTRS


 def create_vocab(lang, defaults, vectors_name=None):
     # If the spacy-lookups-data package is installed, we pre-populate the lookups
     # with lexeme data, if available
     lex_attrs = {**LEX_ATTRS, **defaults.lex_attr_getters}
-    # This is messy, but it's the minimal working fix to Issue #639.
-    lex_attrs[IS_STOP] = functools.partial(is_stop, stops=defaults.stop_words)
-    # Ensure that getter can be pickled
-    lex_attrs[LANG] = functools.partial(get_lang, lang=lang)
-    lex_attrs[NORM] = util.add_lookups(
-        lex_attrs.get(NORM, LEX_ATTRS[NORM]),
-        BASE_NORMS,
-    )
+    defaults.lex_attr_data["stops"] = list(defaults.stop_words)
+    lookups = Lookups()
+    lookups.add_table("lexeme_norm", BASE_NORMS)
     return Vocab(
+        lang=lang,
         lex_attr_getters=lex_attrs,
+        lex_attr_data=defaults.lex_attr_data,
         writing_system=defaults.writing_system,
         get_noun_chunks=defaults.syntax_iterators.get("noun_chunks"),
         vectors_name=vectors_name,
@@ -52,7 +50,8 @@ cdef class Vocab:
     """
     def __init__(self, lex_attr_getters=None, strings=tuple(), lookups=None,
                  oov_prob=-20., vectors_name=None, writing_system={},
-                 get_noun_chunks=None, **deprecated_kwargs):
+                 get_noun_chunks=None, lang="", lex_attr_data=None,
+                 **deprecated_kwargs):
         """Create the vocabulary.

         lex_attr_getters (dict): A dictionary mapping attribute IDs to
@@ -66,9 +65,11 @@ cdef class Vocab:
             A function that yields base noun phrases used for Doc.noun_chunks.
""" lex_attr_getters = lex_attr_getters if lex_attr_getters is not None else {} + lex_attr_data = lex_attr_data if lex_attr_data is not None else {} if lookups in (None, True, False): lookups = Lookups() self.cfg = {'oov_prob': oov_prob} + self.lang = lang self.mem = Pool() self._by_orth = PreshMap() self.strings = StringStore() @@ -77,6 +78,7 @@ cdef class Vocab: for string in strings: _ = self[string] self.lex_attr_getters = lex_attr_getters + self.lex_attr_data = lex_attr_data self.morphology = Morphology(self.strings) self.vectors = Vectors(strings=self.strings, name=vectors_name) self.lookups = lookups @@ -93,13 +95,6 @@ cdef class Vocab: self._vectors = vectors self._vectors.strings = self.strings - @property - def lang(self): - langfunc = None - if self.lex_attr_getters: - langfunc = self.lex_attr_getters.get(LANG, None) - return langfunc("_") if langfunc else "" - def __len__(self): """The current number of lexemes stored. @@ -183,7 +178,7 @@ cdef class Vocab: lex.id = OOV_RANK if self.lex_attr_getters is not None: for attr, func in self.lex_attr_getters.items(): - value = func(string) + value = _get_lex_attr_value(self, func, string) if isinstance(value, str): value = self.strings.add(value) if value is not None: @@ -433,12 +428,23 @@ cdef class Vocab: def __set__(self, lookups): self._lookups = lookups - if lookups.has_table("lexeme_norm"): - self.lex_attr_getters[NORM] = util.add_lookups( - self.lex_attr_getters.get(NORM, LEX_ATTRS[NORM]), - self.lookups.get_table("lexeme_norm"), - ) + self._reset_lexeme_cache() + property lex_attr_getters: + def __get__(self): + return self._lex_attr_getters + + def __set__(self, lex_attr_getters): + self._lex_attr_getters = lex_attr_getters + self._reset_lexeme_cache() + + property lex_attr_data: + def __get__(self): + return self._lex_attr_data + + def __set__(self, lex_attr_data): + self._lex_attr_data = lex_attr_data + self._reset_lexeme_cache() def to_disk(self, path, *, exclude=tuple()): """Save the current state to a directory. @@ -459,6 +465,8 @@ cdef class Vocab: self.vectors.to_disk(path, exclude=["strings"]) if "lookups" not in exclude: self.lookups.to_disk(path) + if "lex_attr_data" not in exclude: + srsly.write_msgpack(path / "lex_attr_data", self.lex_attr_data) def from_disk(self, path, *, exclude=tuple()): """Loads state from a directory. Modifies the object in place and @@ -471,7 +479,6 @@ cdef class Vocab: DOCS: https://spacy.io/api/vocab#to_disk """ path = util.ensure_path(path) - getters = ["strings", "vectors"] if "strings" not in exclude: self.strings.from_disk(path / "strings.json") # TODO: add exclude? 
if "vectors" not in exclude: @@ -479,12 +486,9 @@ cdef class Vocab: self.vectors.from_disk(path, exclude=["strings"]) if "lookups" not in exclude: self.lookups.from_disk(path) - if "lexeme_norm" in self.lookups: - self.lex_attr_getters[NORM] = util.add_lookups( - self.lex_attr_getters.get(NORM, LEX_ATTRS[NORM]), self.lookups.get_table("lexeme_norm") - ) - self.length = 0 - self._by_orth = PreshMap() + if "lex_attr_data" not in exclude: + self.lex_attr_data = srsly.read_msgpack(path / "lex_attr_data") + self._reset_lexeme_cache() return self def to_bytes(self, *, exclude=tuple()): @@ -505,6 +509,7 @@ cdef class Vocab: "strings": lambda: self.strings.to_bytes(), "vectors": deserialize_vectors, "lookups": lambda: self.lookups.to_bytes(), + "lex_attr_data": lambda: srsly.msgpack_dumps(self.lex_attr_data) } return util.to_bytes(getters, exclude) @@ -523,24 +528,27 @@ cdef class Vocab: else: return self.vectors.from_bytes(b, exclude=["strings"]) + def serialize_lex_attr_data(b): + self.lex_attr_data = srsly.msgpack_loads(b) + setters = { "strings": lambda b: self.strings.from_bytes(b), "vectors": lambda b: serialize_vectors(b), "lookups": lambda b: self.lookups.from_bytes(b), + "lex_attr_data": lambda b: serialize_lex_attr_data(b), } util.from_bytes(bytes_data, setters, exclude) - if "lexeme_norm" in self.lookups: - self.lex_attr_getters[NORM] = util.add_lookups( - self.lex_attr_getters.get(NORM, LEX_ATTRS[NORM]), self.lookups.get_table("lexeme_norm") - ) - self.length = 0 - self._by_orth = PreshMap() + self._reset_lexeme_cache() return self def _reset_cache(self, keys, strings): # I'm not sure this made sense. Disable it for now. raise NotImplementedError + def _reset_lexeme_cache(self): + self.length = 0 + self._by_orth = PreshMap() + def pickle_vocab(vocab): sstore = vocab.strings @@ -565,3 +573,12 @@ def unpickle_vocab(sstore, vectors, morphology, lex_attr_getters, lookups, get_n copy_reg.pickle(Vocab, pickle_vocab, unpickle_vocab) + + +def _get_lex_attr_value(vocab, func, string): + if "vocab" in inspect.signature(func).parameters: + value = func(vocab, string) + else: + # TODO: add deprecation warning + value = func(string) + return value