Sketch of lex_attr_getters taking the vocab

Adriane Boyd 2023-02-06 17:09:52 +01:00
parent 6920fb7baf
commit b9a91120b2
11 changed files with 112 additions and 125 deletions
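In short: lexical attribute getters move from func(string) to func(vocab, string), so the data they need (stop words, norm tables, the language code) can live on the vocab via the new lex_attr_data field and the lookups, instead of being bound in with functools.partial. A rough sketch of the signature change (illustrative only, not lifted verbatim from the diff):

# Old style: only the string is available, so per-language data had to be
# baked in, e.g. functools.partial(is_stop, stops=STOP_WORDS).
def is_stop_old(string: str) -> bool:
    stops = {"a", "an", "the"}  # hypothetical baked-in data
    return string.lower() in stops

# New style: the getter receives the vocab and reads shared data at call time.
def is_stop_new(vocab, string: str) -> bool:
    stops = vocab.lex_attr_data.get("stops", [])
    return string.lower() in stops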


@@ -23,21 +23,21 @@ _tlds = set(
 )

-def is_punct(text: str) -> bool:
+def is_punct(vocab, text: str) -> bool:
     for char in text:
         if not unicodedata.category(char).startswith("P"):
             return False
     return True

-def is_ascii(text: str) -> bool:
+def is_ascii(vocab, text: str) -> bool:
     for char in text:
         if ord(char) >= 128:
             return False
     return True

-def like_num(text: str) -> bool:
+def like_num(vocab, text: str) -> bool:
     if text.startswith(("+", "-", "±", "~")):
         text = text[1:]
     # can be overwritten by lang with list of number words
@@ -51,31 +51,31 @@ def like_num(text: str) -> bool:
     return False

-def is_bracket(text: str) -> bool:
+def is_bracket(vocab, text: str) -> bool:
     brackets = ("(", ")", "[", "]", "{", "}", "<", ">")
     return text in brackets

-def is_quote(text: str) -> bool:
+def is_quote(vocab, text: str) -> bool:
     # fmt: off
     quotes = ('"', "'", "`", "«", "»", "", "", "", "", "", "", "", "", "", "", "", "", "''", "``")
     # fmt: on
     return text in quotes

-def is_left_punct(text: str) -> bool:
+def is_left_punct(vocab, text: str) -> bool:
     # fmt: off
     left_punct = ("(", "[", "{", "<", '"', "'", "«", "", "", "", "", "", "", "", "", "``")
     # fmt: on
     return text in left_punct

-def is_right_punct(text: str) -> bool:
+def is_right_punct(vocab, text: str) -> bool:
     right_punct = (")", "]", "}", ">", '"', "'", "»", "", "", "", "", "''")
     return text in right_punct

-def is_currency(text: str) -> bool:
+def is_currency(vocab, text: str) -> bool:
     # can be overwritten by lang with list of currency words, e.g. dollar, euro
     for char in text:
         if unicodedata.category(char) != "Sc":
@@ -83,11 +83,11 @@ def is_currency(text: str) -> bool:
     return True

-def like_email(text: str) -> bool:
+def like_email(vocab, text: str) -> bool:
     return bool(_like_email(text))

-def like_url(text: str) -> bool:
+def like_url(vocab, text: str) -> bool:
     # We're looking for things that function in text like URLs. So, valid URL
     # or not, anything they say http:// is going to be good.
     if text.startswith("http://") or text.startswith("https://"):
@@ -115,7 +115,7 @@ def like_url(text: str) -> bool:
     return False

-def word_shape(text: str) -> str:
+def word_shape(vocab, text: str) -> str:
     if len(text) >= 100:
         return "LONG"
     shape = []
@@ -142,55 +142,54 @@ def word_shape(text: str) -> str:
     return "".join(shape)

-def lower(string: str) -> str:
+def lower(vocab, string: str) -> str:
     return string.lower()

-def prefix(string: str) -> str:
+def prefix(vocab, string: str) -> str:
     return string[0]

-def suffix(string: str) -> str:
+def suffix(vocab, string: str) -> str:
     return string[-3:]

-def is_alpha(string: str) -> bool:
+def is_alpha(vocab, string: str) -> bool:
     return string.isalpha()

-def is_digit(string: str) -> bool:
+def is_digit(vocab, string: str) -> bool:
     return string.isdigit()

-def is_lower(string: str) -> bool:
+def is_lower(vocab, string: str) -> bool:
     return string.islower()

-def is_space(string: str) -> bool:
+def is_space(vocab, string: str) -> bool:
     return string.isspace()

-def is_title(string: str) -> bool:
+def is_title(vocab, string: str) -> bool:
     return string.istitle()

-def is_upper(string: str) -> bool:
+def is_upper(vocab, string: str) -> bool:
     return string.isupper()

-def is_stop(string: str, stops: Set[str] = set()) -> bool:
+def is_stop(vocab, string: str) -> bool:
+    stops = vocab.lex_attr_data.get("stops", [])
     return string.lower() in stops

-def get_lang(text: str, lang: str = "") -> str:
-    # This function is partially applied so lang code can be passed in
-    # automatically while still allowing pickling
-    return lang
+def norm(vocab, string: str) -> str:
+    return vocab.lookups.get_table("lexeme_norm", {}).get(string, string.lower())

 LEX_ATTRS = {
     attrs.LOWER: lower,
-    attrs.NORM: lower,
+    attrs.NORM: norm,
     attrs.PREFIX: prefix,
     attrs.SUFFIX: suffix,
     attrs.IS_ALPHA: is_alpha,
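For reference, a minimal usage sketch of the new vocab-aware getters (not part of the diff; the stop list and norm entry below are made up, while the commit itself fills lex_attr_data["stops"] from the language defaults in create_vocab()):

from spacy.vocab import Vocab
from spacy.lang.lex_attrs import is_stop, norm

vocab = Vocab(lex_attr_data={"stops": ["a", "an", "the"]})
vocab.lookups.add_table("lexeme_norm", {"don't": "do not"})

assert is_stop(vocab, "The")             # reads vocab.lex_attr_data["stops"]
assert norm(vocab, "don't") == "do not"  # hits the lexeme_norm lookups table
assert norm(vocab, "Hello") == "hello"   # falls back to lowercasing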


@@ -74,6 +74,7 @@ class BaseDefaults:
     url_match: Optional[Callable] = URL_MATCH
     syntax_iterators: Dict[str, Callable] = {}
     lex_attr_getters: Dict[int, Callable[[str], Any]] = {}
+    lex_attr_data: Dict[str, Any] = {}
     stop_words: Set[str] = set()
     writing_system = {"direction": "ltr", "has_case": True, "has_letters": True}
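A sketch of how a language's Defaults could use the new field; the class name and data here are illustrative, not from the diff. create_vocab() then copies stop_words into lex_attr_data["stops"] and passes the dict on to the Vocab:

from spacy.language import BaseDefaults
from spacy.lang.lex_attrs import LEX_ATTRS

class CustomDefaults(BaseDefaults):
    lex_attr_getters = dict(LEX_ATTRS)                      # vocab-aware getters
    lex_attr_data = {"num_words": ["one", "two", "three"]}  # shared getter data
    stop_words = {"a", "an", "the"}                         # ends up in lex_attr_data["stops"]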


@@ -2,7 +2,7 @@ from numpy cimport ndarray
 from .typedefs cimport attr_t, hash_t, flags_t, len_t, tag_t
 from .attrs cimport attr_id_t
-from .attrs cimport ID, ORTH, LOWER, NORM, SHAPE, PREFIX, SUFFIX, LENGTH, LANG
+from .attrs cimport ID, ORTH, LOWER, NORM, SHAPE, PREFIX, SUFFIX, LENGTH
 from .structs cimport LexemeC
 from .vocab cimport Vocab
@@ -40,8 +40,6 @@ cdef class Lexeme:
             lex.prefix = value
         elif name == SUFFIX:
             lex.suffix = value
-        elif name == LANG:
-            lex.lang = value

     @staticmethod
     cdef inline attr_t get_struct_attr(const LexemeC* lex, attr_id_t feat_name) nogil:
@@ -66,8 +64,6 @@ cdef class Lexeme:
             return lex.suffix
         elif feat_name == LENGTH:
             return lex.length
-        elif feat_name == LANG:
-            return lex.lang
         else:
             return 0


@@ -247,13 +247,9 @@ cdef class Lexeme:
             cluster_table = self.vocab.lookups.get_table("lexeme_cluster", {})
             cluster_table[self.c.orth] = x

-    property lang:
-        """RETURNS (uint64): Language of the parent vocabulary."""
-        def __get__(self):
-            return self.c.lang
-
-        def __set__(self, attr_t x):
-            self.c.lang = x
+    @property
+    def lang(self):
+        return self.vocab.strings[self.vocab.lang]

     property prob:
         """RETURNS (float): Smoothed log probability estimate of the lexeme's
@@ -316,13 +312,9 @@ cdef class Lexeme:
         def __set__(self, str x):
             self.c.suffix = self.vocab.strings.add(x)

-    property lang_:
-        """RETURNS (str): Language of the parent vocabulary."""
-        def __get__(self):
-            return self.vocab.strings[self.c.lang]
-
-        def __set__(self, str x):
-            self.c.lang = self.vocab.strings.add(x)
+    @property
+    def lang_(self):
+        return self.vocab.lang

     property flags:
         """RETURNS (uint64): Container of the lexeme's binary flags."""


@@ -11,8 +11,6 @@ from .parts_of_speech cimport univ_pos_t
 cdef struct LexemeC:
     flags_t flags

-    attr_t lang
-
     attr_t id
     attr_t length


@@ -9,8 +9,8 @@ from spacy.lang.lex_attrs import like_url, word_shape
 @pytest.mark.parametrize("word", ["the"])
 @pytest.mark.issue(1889)
-def test_issue1889(word):
-    assert is_stop(word, STOP_WORDS) == is_stop(word.upper(), STOP_WORDS)
+def test_issue1889(en_vocab, word):
+    assert is_stop(en_vocab, word) == is_stop(en_vocab, word.upper())

 @pytest.mark.parametrize("text", ["dog"])
@@ -59,13 +59,13 @@ def test_attrs_ent_iob_intify():
 @pytest.mark.parametrize("text,match", [(",", True), (" ", False), ("a", False)])
-def test_lex_attrs_is_punct(text, match):
-    assert is_punct(text) == match
+def test_lex_attrs_is_punct(en_vocab, text, match):
+    assert is_punct(en_vocab, text) == match

 @pytest.mark.parametrize("text,match", [(",", True), ("£", False), ("", False)])
-def test_lex_attrs_is_ascii(text, match):
-    assert is_ascii(text) == match
+def test_lex_attrs_is_ascii(en_vocab, text, match):
+    assert is_ascii(en_vocab, text) == match

 @pytest.mark.parametrize(
@@ -82,8 +82,8 @@ def test_lex_attrs_is_ascii(text, match):
         ("dog", False),
     ],
 )
-def test_lex_attrs_is_currency(text, match):
-    assert is_currency(text) == match
+def test_lex_attrs_is_currency(en_vocab, text, match):
+    assert is_currency(en_vocab, text) == match

 @pytest.mark.parametrize(
@@ -102,8 +102,8 @@ def test_lex_attrs_is_currency(text, match):
         ("hello.There", False),
     ],
 )
-def test_lex_attrs_like_url(text, match):
-    assert like_url(text) == match
+def test_lex_attrs_like_url(en_vocab, text, match):
+    assert like_url(en_vocab, text) == match

 @pytest.mark.parametrize(
@@ -118,5 +118,5 @@ def test_lex_attrs_like_url(text, match):
         ("``,-", "``,-"),
     ],
 )
-def test_lex_attrs_word_shape(text, shape):
-    assert word_shape(text) == shape
+def test_lex_attrs_word_shape(en_vocab, text, shape):
+    assert word_shape(en_vocab, text) == shape


@@ -6,10 +6,12 @@ from thinc.api import get_current_ops
 import spacy
 from spacy.lang.en import English
 from spacy.strings import StringStore
+from spacy.symbols import NORM
 from spacy.tokens import Doc
 from spacy.util import ensure_path, load_model
 from spacy.vectors import Vectors
 from spacy.vocab import Vocab
+from spacy.lang.lex_attrs import norm

 from ..util import make_tempdir
@@ -73,6 +75,7 @@ def test_serialize_vocab(en_vocab, text):
     new_vocab = Vocab().from_bytes(vocab_bytes)
     assert new_vocab.strings[text_hash] == text
     assert new_vocab.to_bytes(exclude=["lookups"]) == vocab_bytes
+    assert new_vocab.lex_attr_data == en_vocab.lex_attr_data

 @pytest.mark.parametrize("strings1,strings2", test_strings)
@@ -111,12 +114,13 @@ def test_serialize_vocab_roundtrip_disk(strings1, strings2):
         assert [s for s in vocab1_d.strings] == [s for s in vocab2_d.strings]
     else:
         assert [s for s in vocab1_d.strings] != [s for s in vocab2_d.strings]
+    assert vocab1_d.lex_attr_data == vocab2_d.lex_attr_data

 @pytest.mark.parametrize("strings,lex_attr", test_strings_attrs)
 def test_serialize_vocab_lex_attrs_bytes(strings, lex_attr):
-    vocab1 = Vocab(strings=strings)
-    vocab2 = Vocab()
+    vocab1 = Vocab(strings=strings, lex_attr_getters={NORM: norm})
+    vocab2 = Vocab(lex_attr_getters={NORM: norm})
     vocab1[strings[0]].norm_ = lex_attr
     assert vocab1[strings[0]].norm_ == lex_attr
     assert vocab2[strings[0]].norm_ != lex_attr
@@ -134,8 +138,8 @@ def test_deserialize_vocab_seen_entries(strings, lex_attr):
 @pytest.mark.parametrize("strings,lex_attr", test_strings_attrs)
 def test_serialize_vocab_lex_attrs_disk(strings, lex_attr):
-    vocab1 = Vocab(strings=strings)
-    vocab2 = Vocab()
+    vocab1 = Vocab(strings=strings, lex_attr_getters={NORM: norm})
+    vocab2 = Vocab(lex_attr_getters={NORM: norm})
     vocab1[strings[0]].norm_ = lex_attr
     assert vocab1[strings[0]].norm_ == lex_attr
     assert vocab2[strings[0]].norm_ != lex_attr
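A quick round-trip sketch for the new lex_attr_data serialization, mirroring the tests above (data values are illustrative):

from spacy.vocab import Vocab

vocab1 = Vocab(lex_attr_data={"stops": ["a", "an", "the"]})
vocab2 = Vocab().from_bytes(vocab1.to_bytes())
assert vocab2.lex_attr_data == vocab1.lex_attr_data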


@@ -288,7 +288,7 @@ cdef class Token:
         """RETURNS (uint64): ID of the language of the parent document's
             vocabulary.
         """
-        return self.c.lex.lang
+        return self.doc.lang

     @property
     def idx(self):
@@ -844,7 +844,7 @@ cdef class Token:
         """RETURNS (str): Language of the parent document's vocabulary,
            e.g. 'en'.
         """
-        return self.vocab.strings[self.c.lex.lang]
+        return self.doc.lang_

     property lemma_:
         """RETURNS (str): The token lemma, i.e. the base form of the word,


@@ -1160,28 +1160,6 @@ def compile_infix_regex(entries: Iterable[Union[str, Pattern]]) -> Pattern:
     return re.compile(expression)

-def add_lookups(default_func: Callable[[str], Any], *lookups) -> Callable[[str], Any]:
-    """Extend an attribute function with special cases. If a word is in the
-    lookups, the value is returned. Otherwise the previous function is used.
-
-    default_func (callable): The default function to execute.
-    *lookups (dict): Lookup dictionary mapping string to attribute value.
-    RETURNS (callable): Lexical attribute getter.
-    """
-    # This is implemented as functools.partial instead of a closure, to allow
-    # pickle to work.
-    return functools.partial(_get_attr_unless_lookup, default_func, lookups)
-
-
-def _get_attr_unless_lookup(
-    default_func: Callable[[str], Any], lookups: Dict[str, Any], string: str
-) -> Any:
-    for lookup in lookups:
-        if string in lookup:
-            return lookup[string]  # type: ignore[index]
-    return default_func(string)
-
-
 def update_exc(
     base_exceptions: Dict[str, List[dict]], *addition_dicts
 ) -> Dict[str, List[dict]]:


@@ -27,12 +27,14 @@ cdef class Vocab:
     cdef Pool mem
     cdef readonly StringStore strings
     cdef public Morphology morphology
+    cdef public str lang
     cdef public object _vectors
     cdef public object _lookups
+    cdef public object _lex_attr_getters
+    cdef public object _lex_attr_data
     cdef public object writing_system
     cdef public object get_noun_chunks
     cdef readonly int length
-    cdef public object lex_attr_getters
     cdef public object cfg

     cdef const LexemeC* get(self, str string) except NULL


@@ -5,6 +5,7 @@ import numpy
 import srsly
 from thinc.api import get_array_module, get_current_ops
 import functools
+import inspect

 from .lexeme cimport EMPTY_LEXEME, OOV_RANK
 from .lexeme cimport Lexeme
@@ -14,29 +15,26 @@ from .attrs cimport LANG, ORTH
 from .compat import copy_reg
 from .errors import Errors
-from .attrs import intify_attrs, NORM, IS_STOP
+from .attrs import intify_attrs, NORM
 from .vectors import Vectors, Mode as VectorsMode
 from .util import registry
 from .lookups import Lookups
 from . import util
 from .lang.norm_exceptions import BASE_NORMS
-from .lang.lex_attrs import LEX_ATTRS, is_stop, get_lang
+from .lang.lex_attrs import LEX_ATTRS


 def create_vocab(lang, defaults, vectors_name=None):
     # If the spacy-lookups-data package is installed, we pre-populate the lookups
     # with lexeme data, if available
     lex_attrs = {**LEX_ATTRS, **defaults.lex_attr_getters}
-    # This is messy, but it's the minimal working fix to Issue #639.
-    lex_attrs[IS_STOP] = functools.partial(is_stop, stops=defaults.stop_words)
-    # Ensure that getter can be pickled
-    lex_attrs[LANG] = functools.partial(get_lang, lang=lang)
-    lex_attrs[NORM] = util.add_lookups(
-        lex_attrs.get(NORM, LEX_ATTRS[NORM]),
-        BASE_NORMS,
-    )
+    defaults.lex_attr_data["stops"] = list(defaults.stop_words)
+    lookups = Lookups()
+    lookups.add_table("lexeme_norm", BASE_NORMS)
     return Vocab(
+        lang=lang,
         lex_attr_getters=lex_attrs,
+        lex_attr_data=defaults.lex_attr_data,
         writing_system=defaults.writing_system,
         get_noun_chunks=defaults.syntax_iterators.get("noun_chunks"),
         vectors_name=vectors_name,
@@ -52,7 +50,8 @@ cdef class Vocab:
     """
     def __init__(self, lex_attr_getters=None, strings=tuple(), lookups=None,
                  oov_prob=-20., vectors_name=None, writing_system={},
-                 get_noun_chunks=None, **deprecated_kwargs):
+                 get_noun_chunks=None, lang="", lex_attr_data=None,
+                 **deprecated_kwargs):
         """Create the vocabulary.

         lex_attr_getters (dict): A dictionary mapping attribute IDs to
@@ -66,9 +65,11 @@ cdef class Vocab:
             A function that yields base noun phrases used for Doc.noun_chunks.
         """
         lex_attr_getters = lex_attr_getters if lex_attr_getters is not None else {}
+        lex_attr_data = lex_attr_data if lex_attr_data is not None else {}
         if lookups in (None, True, False):
             lookups = Lookups()
         self.cfg = {'oov_prob': oov_prob}
+        self.lang = lang
         self.mem = Pool()
         self._by_orth = PreshMap()
         self.strings = StringStore()
@@ -77,6 +78,7 @@ cdef class Vocab:
         for string in strings:
             _ = self[string]
         self.lex_attr_getters = lex_attr_getters
+        self.lex_attr_data = lex_attr_data
         self.morphology = Morphology(self.strings)
         self.vectors = Vectors(strings=self.strings, name=vectors_name)
         self.lookups = lookups
@@ -93,13 +95,6 @@ cdef class Vocab:
             self._vectors = vectors
             self._vectors.strings = self.strings

-    @property
-    def lang(self):
-        langfunc = None
-        if self.lex_attr_getters:
-            langfunc = self.lex_attr_getters.get(LANG, None)
-        return langfunc("_") if langfunc else ""
-
     def __len__(self):
         """The current number of lexemes stored.
@@ -183,7 +178,7 @@ cdef class Vocab:
         lex.id = OOV_RANK
         if self.lex_attr_getters is not None:
             for attr, func in self.lex_attr_getters.items():
-                value = func(string)
+                value = _get_lex_attr_value(self, func, string)
                 if isinstance(value, str):
                     value = self.strings.add(value)
                 if value is not None:
@@ -433,12 +428,23 @@ cdef class Vocab:
         def __set__(self, lookups):
             self._lookups = lookups
-            if lookups.has_table("lexeme_norm"):
-                self.lex_attr_getters[NORM] = util.add_lookups(
-                    self.lex_attr_getters.get(NORM, LEX_ATTRS[NORM]),
-                    self.lookups.get_table("lexeme_norm"),
-                )
+            self._reset_lexeme_cache()
+
+    property lex_attr_getters:
+        def __get__(self):
+            return self._lex_attr_getters
+
+        def __set__(self, lex_attr_getters):
+            self._lex_attr_getters = lex_attr_getters
+            self._reset_lexeme_cache()
+
+    property lex_attr_data:
+        def __get__(self):
+            return self._lex_attr_data
+
+        def __set__(self, lex_attr_data):
+            self._lex_attr_data = lex_attr_data
+            self._reset_lexeme_cache()

     def to_disk(self, path, *, exclude=tuple()):
         """Save the current state to a directory.
@@ -459,6 +465,8 @@ cdef class Vocab:
                 self.vectors.to_disk(path, exclude=["strings"])
         if "lookups" not in exclude:
             self.lookups.to_disk(path)
+        if "lex_attr_data" not in exclude:
+            srsly.write_msgpack(path / "lex_attr_data", self.lex_attr_data)

     def from_disk(self, path, *, exclude=tuple()):
         """Loads state from a directory. Modifies the object in place and
@@ -471,7 +479,6 @@ cdef class Vocab:
         DOCS: https://spacy.io/api/vocab#to_disk
         """
         path = util.ensure_path(path)
-        getters = ["strings", "vectors"]
         if "strings" not in exclude:
             self.strings.from_disk(path / "strings.json")  # TODO: add exclude?
         if "vectors" not in exclude:
@@ -479,12 +486,9 @@ cdef class Vocab:
                 self.vectors.from_disk(path, exclude=["strings"])
         if "lookups" not in exclude:
             self.lookups.from_disk(path)
-        if "lexeme_norm" in self.lookups:
-            self.lex_attr_getters[NORM] = util.add_lookups(
-                self.lex_attr_getters.get(NORM, LEX_ATTRS[NORM]), self.lookups.get_table("lexeme_norm")
-            )
-        self.length = 0
-        self._by_orth = PreshMap()
+        if "lex_attr_data" not in exclude:
+            self.lex_attr_data = srsly.read_msgpack(path / "lex_attr_data")
+        self._reset_lexeme_cache()
         return self

     def to_bytes(self, *, exclude=tuple()):
@@ -505,6 +509,7 @@ cdef class Vocab:
             "strings": lambda: self.strings.to_bytes(),
             "vectors": deserialize_vectors,
             "lookups": lambda: self.lookups.to_bytes(),
+            "lex_attr_data": lambda: srsly.msgpack_dumps(self.lex_attr_data)
         }
         return util.to_bytes(getters, exclude)
@@ -523,24 +528,27 @@ cdef class Vocab:
             else:
                 return self.vectors.from_bytes(b, exclude=["strings"])

+        def serialize_lex_attr_data(b):
+            self.lex_attr_data = srsly.msgpack_loads(b)
+
         setters = {
             "strings": lambda b: self.strings.from_bytes(b),
             "vectors": lambda b: serialize_vectors(b),
             "lookups": lambda b: self.lookups.from_bytes(b),
+            "lex_attr_data": lambda b: serialize_lex_attr_data(b),
         }
         util.from_bytes(bytes_data, setters, exclude)
-        if "lexeme_norm" in self.lookups:
-            self.lex_attr_getters[NORM] = util.add_lookups(
-                self.lex_attr_getters.get(NORM, LEX_ATTRS[NORM]), self.lookups.get_table("lexeme_norm")
-            )
-        self.length = 0
-        self._by_orth = PreshMap()
+        self._reset_lexeme_cache()
         return self

     def _reset_cache(self, keys, strings):
         # I'm not sure this made sense. Disable it for now.
         raise NotImplementedError

+    def _reset_lexeme_cache(self):
+        self.length = 0
+        self._by_orth = PreshMap()
+

 def pickle_vocab(vocab):
     sstore = vocab.strings
@@ -565,3 +573,12 @@ def unpickle_vocab(sstore, vectors, morphology, lex_attr_getters, lookups, get_n
 copy_reg.pickle(Vocab, pickle_vocab, unpickle_vocab)
+
+
+def _get_lex_attr_value(vocab, func, string):
+    if "vocab" in inspect.signature(func).parameters:
+        value = func(vocab, string)
+    else:
+        # TODO: add deprecation warning
+        value = func(string)
+    return value
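Finally, a sketch of what the _get_lex_attr_value dispatch buys: an old-style single-argument getter keeps working when registered on the vocab, side by side with a new vocab-aware one (the example getter and values are illustrative, not part of the diff):

from spacy.attrs import IS_UPPER, NORM
from spacy.lang.lex_attrs import norm   # new-style getter: norm(vocab, string)
from spacy.vocab import Vocab

def legacy_is_upper(string):            # old-style getter, no vocab parameter
    return string.isupper()

vocab = Vocab(lex_attr_getters={IS_UPPER: legacy_is_upper, NORM: norm})
lex = vocab["SPACY"]
assert lex.is_upper                     # dispatched as func(string)
assert lex.norm_ == "spacy"             # dispatched as func(vocab, string)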