Sketch of lex_attr_getters taking the vocab

parent 6920fb7baf
commit b9a91120b2
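
This commit moves lexical attribute getters from a one-argument `(string)` signature to a two-argument `(vocab, string)` signature, so per-language data such as stop words and norm tables can be read from the vocab at call time instead of being bound into the getter with `functools.partial`. A minimal sketch of the two conventions, for orientation only (`is_stop_old` and `is_stop_new` are stand-in names; the new-style body mirrors the `is_stop` implementation in the diff below):

# Old convention: data such as stop words had to be bound into the getter
# ahead of time, e.g. via functools.partial(is_stop, stops=STOP_WORDS).
def is_stop_old(string):
    return string.lower() in {"the", "a", "an"}


# New convention: the vocab is the first argument, so the getter can read
# shared data from vocab.lex_attr_data when it is called.
def is_stop_new(vocab, string):
    stops = vocab.lex_attr_data.get("stops", [])
    return string.lower() in stops
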
@@ -23,21 +23,21 @@ _tlds = set(
 )


-def is_punct(text: str) -> bool:
+def is_punct(vocab, text: str) -> bool:
     for char in text:
         if not unicodedata.category(char).startswith("P"):
             return False
     return True


-def is_ascii(text: str) -> bool:
+def is_ascii(vocab, text: str) -> bool:
     for char in text:
         if ord(char) >= 128:
             return False
     return True


-def like_num(text: str) -> bool:
+def like_num(vocab, text: str) -> bool:
     if text.startswith(("+", "-", "±", "~")):
         text = text[1:]
     # can be overwritten by lang with list of number words

@@ -51,31 +51,31 @@ def like_num(text: str) -> bool:
     return False


-def is_bracket(text: str) -> bool:
+def is_bracket(vocab, text: str) -> bool:
     brackets = ("(", ")", "[", "]", "{", "}", "<", ">")
     return text in brackets


-def is_quote(text: str) -> bool:
+def is_quote(vocab, text: str) -> bool:
     # fmt: off
     quotes = ('"', "'", "`", "«", "»", "‘", "’", "‚", "‛", "“", "”", "„", "‟", "‹", "›", "❮", "❯", "''", "``")
     # fmt: on
     return text in quotes


-def is_left_punct(text: str) -> bool:
+def is_left_punct(vocab, text: str) -> bool:
     # fmt: off
     left_punct = ("(", "[", "{", "<", '"', "'", "«", "‘", "‚", "‛", "“", "„", "‟", "‹", "❮", "``")
     # fmt: on
     return text in left_punct


-def is_right_punct(text: str) -> bool:
+def is_right_punct(vocab, text: str) -> bool:
     right_punct = (")", "]", "}", ">", '"', "'", "»", "’", "”", "›", "❯", "''")
     return text in right_punct


-def is_currency(text: str) -> bool:
+def is_currency(vocab, text: str) -> bool:
     # can be overwritten by lang with list of currency words, e.g. dollar, euro
     for char in text:
         if unicodedata.category(char) != "Sc":

@@ -83,11 +83,11 @@ def is_currency(text: str) -> bool:
     return True


-def like_email(text: str) -> bool:
+def like_email(vocab, text: str) -> bool:
     return bool(_like_email(text))


-def like_url(text: str) -> bool:
+def like_url(vocab, text: str) -> bool:
     # We're looking for things that function in text like URLs. So, valid URL
     # or not, anything they say http:// is going to be good.
     if text.startswith("http://") or text.startswith("https://"):

@@ -115,7 +115,7 @@ def like_url(text: str) -> bool:
     return False


-def word_shape(text: str) -> str:
+def word_shape(vocab, text: str) -> str:
     if len(text) >= 100:
         return "LONG"
     shape = []

@@ -142,55 +142,54 @@ def word_shape(text: str) -> str:
     return "".join(shape)


-def lower(string: str) -> str:
+def lower(vocab, string: str) -> str:
     return string.lower()


-def prefix(string: str) -> str:
+def prefix(vocab, string: str) -> str:
     return string[0]


-def suffix(string: str) -> str:
+def suffix(vocab, string: str) -> str:
     return string[-3:]


-def is_alpha(string: str) -> bool:
+def is_alpha(vocab, string: str) -> bool:
     return string.isalpha()


-def is_digit(string: str) -> bool:
+def is_digit(vocab, string: str) -> bool:
     return string.isdigit()


-def is_lower(string: str) -> bool:
+def is_lower(vocab, string: str) -> bool:
     return string.islower()


-def is_space(string: str) -> bool:
+def is_space(vocab, string: str) -> bool:
     return string.isspace()


-def is_title(string: str) -> bool:
+def is_title(vocab, string: str) -> bool:
     return string.istitle()


-def is_upper(string: str) -> bool:
+def is_upper(vocab, string: str) -> bool:
     return string.isupper()


-def is_stop(string: str, stops: Set[str] = set()) -> bool:
+def is_stop(vocab, string: str) -> bool:
+    stops = vocab.lex_attr_data.get("stops", [])
     return string.lower() in stops


-def get_lang(text: str, lang: str = "") -> str:
-    # This function is partially applied so lang code can be passed in
-    # automatically while still allowing pickling
-    return lang
+def norm(vocab, string: str) -> str:
+    return vocab.lookups.get_table("lexeme_norm", {}).get(string, string.lower())


 LEX_ATTRS = {
     attrs.LOWER: lower,
-    attrs.NORM: lower,
+    attrs.NORM: norm,
     attrs.PREFIX: prefix,
     attrs.SUFFIX: suffix,
     attrs.IS_ALPHA: is_alpha,

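With the new signatures, the module-level getters are called with the vocab as the first argument, as the updated tests further down do. A short usage sketch, assuming this branch's `lex_attr_data`/`lookups`-backed behaviour:

from spacy.lang.en import English
from spacy.lang.lex_attrs import is_punct, norm

nlp = English()
# Every getter now takes the vocab first.
assert is_punct(nlp.vocab, ",") is True
# norm() falls back to the lowercased string when there is no lookup entry.
assert norm(nlp.vocab, "Hello") == "hello"
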
@@ -74,6 +74,7 @@ class BaseDefaults:
     url_match: Optional[Callable] = URL_MATCH
     syntax_iterators: Dict[str, Callable] = {}
     lex_attr_getters: Dict[int, Callable[[str], Any]] = {}
+    lex_attr_data: Dict[str, Any] = {}
     stop_words: Set[str] = set()
     writing_system = {"direction": "ltr", "has_case": True, "has_letters": True}

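The new `lex_attr_data` field gives a language's defaults a plain-data slot that vocab-aware getters can read; `create_vocab` further down also copies `stop_words` into it. A hypothetical defaults subclass might look like the following sketch (the class names, the `"zz"` code and the `"num_words"` key are illustrative only):

from spacy.language import BaseDefaults, Language


class CustomDefaults(BaseDefaults):
    stop_words = {"foo", "bar"}
    # Plain data that getters receive via vocab.lex_attr_data at call time.
    lex_attr_data = {"num_words": ["one", "two", "three"]}


class CustomLanguage(Language):
    lang = "zz"
    Defaults = CustomDefaults
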
@@ -2,7 +2,7 @@ from numpy cimport ndarray

 from .typedefs cimport attr_t, hash_t, flags_t, len_t, tag_t
 from .attrs cimport attr_id_t
-from .attrs cimport ID, ORTH, LOWER, NORM, SHAPE, PREFIX, SUFFIX, LENGTH, LANG
+from .attrs cimport ID, ORTH, LOWER, NORM, SHAPE, PREFIX, SUFFIX, LENGTH

 from .structs cimport LexemeC
 from .vocab cimport Vocab

@@ -40,8 +40,6 @@ cdef class Lexeme:
             lex.prefix = value
         elif name == SUFFIX:
             lex.suffix = value
-        elif name == LANG:
-            lex.lang = value

     @staticmethod
     cdef inline attr_t get_struct_attr(const LexemeC* lex, attr_id_t feat_name) nogil:

@@ -66,8 +64,6 @@ cdef class Lexeme:
             return lex.suffix
         elif feat_name == LENGTH:
             return lex.length
-        elif feat_name == LANG:
-            return lex.lang
         else:
             return 0

@@ -247,13 +247,9 @@ cdef class Lexeme:
             cluster_table = self.vocab.lookups.get_table("lexeme_cluster", {})
             cluster_table[self.c.orth] = x

-    property lang:
-        """RETURNS (uint64): Language of the parent vocabulary."""
-        def __get__(self):
-            return self.c.lang
-
-        def __set__(self, attr_t x):
-            self.c.lang = x
+    @property
+    def lang(self):
+        return self.vocab.strings[self.vocab.lang]

     property prob:
         """RETURNS (float): Smoothed log probability estimate of the lexeme's

@@ -316,13 +312,9 @@ cdef class Lexeme:
         def __set__(self, str x):
             self.c.suffix = self.vocab.strings.add(x)

-    property lang_:
-        """RETURNS (str): Language of the parent vocabulary."""
-        def __get__(self):
-            return self.vocab.strings[self.c.lang]
-
-        def __set__(self, str x):
-            self.c.lang = self.vocab.strings.add(x)
+    @property
+    def lang_(self):
+        return self.vocab.lang

     property flags:
         """RETURNS (uint64): Container of the lexeme's binary flags."""

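With the `lang` field about to be dropped from `LexemeC` (next hunk), `Lexeme.lang` and `Lexeme.lang_` become read-only views onto the vocab's language. A small sketch of the expected behaviour on this branch, where `Vocab.lang` is set by `create_vocab`:

from spacy.lang.en import English

nlp = English()
lex = nlp.vocab["apples"]
# lang_ now reflects vocab.lang instead of a per-lexeme struct field.
assert lex.lang_ == "en"
assert lex.lang == nlp.vocab.strings["en"]
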
@@ -11,8 +11,6 @@ from .parts_of_speech cimport univ_pos_t
 cdef struct LexemeC:
     flags_t flags

-    attr_t lang
-
     attr_t id
     attr_t length

@@ -9,8 +9,8 @@ from spacy.lang.lex_attrs import like_url, word_shape

 @pytest.mark.parametrize("word", ["the"])
 @pytest.mark.issue(1889)
-def test_issue1889(word):
-    assert is_stop(word, STOP_WORDS) == is_stop(word.upper(), STOP_WORDS)
+def test_issue1889(en_vocab, word):
+    assert is_stop(en_vocab, word) == is_stop(en_vocab, word.upper())


 @pytest.mark.parametrize("text", ["dog"])

@@ -59,13 +59,13 @@ def test_attrs_ent_iob_intify():


 @pytest.mark.parametrize("text,match", [(",", True), (" ", False), ("a", False)])
-def test_lex_attrs_is_punct(text, match):
-    assert is_punct(text) == match
+def test_lex_attrs_is_punct(en_vocab, text, match):
+    assert is_punct(en_vocab, text) == match


 @pytest.mark.parametrize("text,match", [(",", True), ("£", False), ("♥", False)])
-def test_lex_attrs_is_ascii(text, match):
-    assert is_ascii(text) == match
+def test_lex_attrs_is_ascii(en_vocab, text, match):
+    assert is_ascii(en_vocab, text) == match


 @pytest.mark.parametrize(

@@ -82,8 +82,8 @@ def test_lex_attrs_is_ascii(text, match):
         ("dog", False),
     ],
 )
-def test_lex_attrs_is_currency(text, match):
-    assert is_currency(text) == match
+def test_lex_attrs_is_currency(en_vocab, text, match):
+    assert is_currency(en_vocab, text) == match


 @pytest.mark.parametrize(

@@ -102,8 +102,8 @@ def test_lex_attrs_is_currency(text, match):
         ("hello.There", False),
     ],
 )
-def test_lex_attrs_like_url(text, match):
-    assert like_url(text) == match
+def test_lex_attrs_like_url(en_vocab, text, match):
+    assert like_url(en_vocab, text) == match


 @pytest.mark.parametrize(

@@ -118,5 +118,5 @@ def test_lex_attrs_like_url(text, match):
         ("``,-", "``,-"),
     ],
 )
-def test_lex_attrs_word_shape(text, shape):
-    assert word_shape(text) == shape
+def test_lex_attrs_word_shape(en_vocab, text, shape):
+    assert word_shape(en_vocab, text) == shape

@@ -6,10 +6,12 @@ from thinc.api import get_current_ops
 import spacy
 from spacy.lang.en import English
 from spacy.strings import StringStore
+from spacy.symbols import NORM
 from spacy.tokens import Doc
 from spacy.util import ensure_path, load_model
 from spacy.vectors import Vectors
 from spacy.vocab import Vocab
+from spacy.lang.lex_attrs import norm

 from ..util import make_tempdir

@@ -73,6 +75,7 @@ def test_serialize_vocab(en_vocab, text):
     new_vocab = Vocab().from_bytes(vocab_bytes)
     assert new_vocab.strings[text_hash] == text
     assert new_vocab.to_bytes(exclude=["lookups"]) == vocab_bytes
+    assert new_vocab.lex_attr_data == en_vocab.lex_attr_data


 @pytest.mark.parametrize("strings1,strings2", test_strings)

@@ -111,12 +114,13 @@ def test_serialize_vocab_roundtrip_disk(strings1, strings2):
         assert [s for s in vocab1_d.strings] == [s for s in vocab2_d.strings]
     else:
         assert [s for s in vocab1_d.strings] != [s for s in vocab2_d.strings]
+    assert vocab1_d.lex_attr_data == vocab2_d.lex_attr_data


 @pytest.mark.parametrize("strings,lex_attr", test_strings_attrs)
 def test_serialize_vocab_lex_attrs_bytes(strings, lex_attr):
-    vocab1 = Vocab(strings=strings)
-    vocab2 = Vocab()
+    vocab1 = Vocab(strings=strings, lex_attr_getters={NORM: norm})
+    vocab2 = Vocab(lex_attr_getters={NORM: norm})
     vocab1[strings[0]].norm_ = lex_attr
     assert vocab1[strings[0]].norm_ == lex_attr
     assert vocab2[strings[0]].norm_ != lex_attr

@@ -134,8 +138,8 @@ def test_deserialize_vocab_seen_entries(strings, lex_attr):

 @pytest.mark.parametrize("strings,lex_attr", test_strings_attrs)
 def test_serialize_vocab_lex_attrs_disk(strings, lex_attr):
-    vocab1 = Vocab(strings=strings)
-    vocab2 = Vocab()
+    vocab1 = Vocab(strings=strings, lex_attr_getters={NORM: norm})
+    vocab2 = Vocab(lex_attr_getters={NORM: norm})
     vocab1[strings[0]].norm_ = lex_attr
     assert vocab1[strings[0]].norm_ == lex_attr
     assert vocab2[strings[0]].norm_ != lex_attr

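The serialization tests above assume `lex_attr_data` travels with the vocab. A minimal round-trip sketch under that assumption (the constructor keywords `lang` and `lex_attr_data` only exist on this branch):

from spacy.vocab import Vocab

vocab = Vocab(lang="en", lex_attr_data={"stops": ["the", "a"]})
restored = Vocab().from_bytes(vocab.to_bytes())
# The plain-data dict is serialized alongside strings, vectors and lookups.
assert restored.lex_attr_data == {"stops": ["the", "a"]}
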
@@ -288,7 +288,7 @@ cdef class Token:
         """RETURNS (uint64): ID of the language of the parent document's
             vocabulary.
         """
-        return self.c.lex.lang
+        return self.doc.lang

     @property
     def idx(self):

@@ -844,7 +844,7 @@ cdef class Token:
         """RETURNS (str): Language of the parent document's vocabulary,
             e.g. 'en'.
         """
-        return self.vocab.strings[self.c.lex.lang]
+        return self.doc.lang_

     property lemma_:
         """RETURNS (str): The token lemma, i.e. the base form of the word,

@@ -1160,28 +1160,6 @@ def compile_infix_regex(entries: Iterable[Union[str, Pattern]]) -> Pattern:
     return re.compile(expression)


-def add_lookups(default_func: Callable[[str], Any], *lookups) -> Callable[[str], Any]:
-    """Extend an attribute function with special cases. If a word is in the
-    lookups, the value is returned. Otherwise the previous function is used.
-
-    default_func (callable): The default function to execute.
-    *lookups (dict): Lookup dictionary mapping string to attribute value.
-    RETURNS (callable): Lexical attribute getter.
-    """
-    # This is implemented as functools.partial instead of a closure, to allow
-    # pickle to work.
-    return functools.partial(_get_attr_unless_lookup, default_func, lookups)
-
-
-def _get_attr_unless_lookup(
-    default_func: Callable[[str], Any], lookups: Dict[str, Any], string: str
-) -> Any:
-    for lookup in lookups:
-        if string in lookup:
-            return lookup[string]  # type: ignore[index]
-    return default_func(string)
-
-
 def update_exc(
     base_exceptions: Dict[str, List[dict]], *addition_dicts
 ) -> Dict[str, List[dict]]:

@@ -27,12 +27,14 @@ cdef class Vocab:
     cdef Pool mem
     cdef readonly StringStore strings
     cdef public Morphology morphology
+    cdef public str lang
     cdef public object _vectors
     cdef public object _lookups
+    cdef public object _lex_attr_getters
+    cdef public object _lex_attr_data
     cdef public object writing_system
     cdef public object get_noun_chunks
     cdef readonly int length
-    cdef public object lex_attr_getters
     cdef public object cfg

     cdef const LexemeC* get(self, str string) except NULL

@@ -5,6 +5,7 @@ import numpy
 import srsly
 from thinc.api import get_array_module, get_current_ops
 import functools
+import inspect

 from .lexeme cimport EMPTY_LEXEME, OOV_RANK
 from .lexeme cimport Lexeme

@@ -14,29 +15,26 @@ from .attrs cimport LANG, ORTH

 from .compat import copy_reg
 from .errors import Errors
-from .attrs import intify_attrs, NORM, IS_STOP
+from .attrs import intify_attrs, NORM
 from .vectors import Vectors, Mode as VectorsMode
 from .util import registry
 from .lookups import Lookups
 from . import util
 from .lang.norm_exceptions import BASE_NORMS
-from .lang.lex_attrs import LEX_ATTRS, is_stop, get_lang
+from .lang.lex_attrs import LEX_ATTRS


 def create_vocab(lang, defaults, vectors_name=None):
     # If the spacy-lookups-data package is installed, we pre-populate the lookups
     # with lexeme data, if available
     lex_attrs = {**LEX_ATTRS, **defaults.lex_attr_getters}
-    # This is messy, but it's the minimal working fix to Issue #639.
-    lex_attrs[IS_STOP] = functools.partial(is_stop, stops=defaults.stop_words)
-    # Ensure that getter can be pickled
-    lex_attrs[LANG] = functools.partial(get_lang, lang=lang)
-    lex_attrs[NORM] = util.add_lookups(
-        lex_attrs.get(NORM, LEX_ATTRS[NORM]),
-        BASE_NORMS,
-    )
+    defaults.lex_attr_data["stops"] = list(defaults.stop_words)
+    lookups = Lookups()
+    lookups.add_table("lexeme_norm", BASE_NORMS)
     return Vocab(
+        lang=lang,
         lex_attr_getters=lex_attrs,
+        lex_attr_data=defaults.lex_attr_data,
         writing_system=defaults.writing_system,
         get_noun_chunks=defaults.syntax_iterators.get("noun_chunks"),
         vectors_name=vectors_name,

@@ -52,7 +50,8 @@ cdef class Vocab:
     """
     def __init__(self, lex_attr_getters=None, strings=tuple(), lookups=None,
                  oov_prob=-20., vectors_name=None, writing_system={},
-                 get_noun_chunks=None, **deprecated_kwargs):
+                 get_noun_chunks=None, lang="", lex_attr_data=None,
+                 **deprecated_kwargs):
         """Create the vocabulary.

         lex_attr_getters (dict): A dictionary mapping attribute IDs to

@@ -66,9 +65,11 @@ cdef class Vocab:
             A function that yields base noun phrases used for Doc.noun_chunks.
         """
         lex_attr_getters = lex_attr_getters if lex_attr_getters is not None else {}
+        lex_attr_data = lex_attr_data if lex_attr_data is not None else {}
         if lookups in (None, True, False):
             lookups = Lookups()
         self.cfg = {'oov_prob': oov_prob}
+        self.lang = lang
         self.mem = Pool()
         self._by_orth = PreshMap()
         self.strings = StringStore()

@@ -77,6 +78,7 @@ cdef class Vocab:
         for string in strings:
             _ = self[string]
         self.lex_attr_getters = lex_attr_getters
+        self.lex_attr_data = lex_attr_data
         self.morphology = Morphology(self.strings)
         self.vectors = Vectors(strings=self.strings, name=vectors_name)
         self.lookups = lookups

@@ -93,13 +95,6 @@ cdef class Vocab:
             self._vectors = vectors
             self._vectors.strings = self.strings

-    @property
-    def lang(self):
-        langfunc = None
-        if self.lex_attr_getters:
-            langfunc = self.lex_attr_getters.get(LANG, None)
-        return langfunc("_") if langfunc else ""
-
     def __len__(self):
         """The current number of lexemes stored.

@@ -183,7 +178,7 @@ cdef class Vocab:
         lex.id = OOV_RANK
         if self.lex_attr_getters is not None:
             for attr, func in self.lex_attr_getters.items():
-                value = func(string)
+                value = _get_lex_attr_value(self, func, string)
                 if isinstance(value, str):
                     value = self.strings.add(value)
                 if value is not None:

@@ -433,12 +428,23 @@ cdef class Vocab:

         def __set__(self, lookups):
             self._lookups = lookups
-            if lookups.has_table("lexeme_norm"):
-                self.lex_attr_getters[NORM] = util.add_lookups(
-                    self.lex_attr_getters.get(NORM, LEX_ATTRS[NORM]),
-                    self.lookups.get_table("lexeme_norm"),
-                )
+            self._reset_lexeme_cache()
+
+    property lex_attr_getters:
+        def __get__(self):
+            return self._lex_attr_getters
+
+        def __set__(self, lex_attr_getters):
+            self._lex_attr_getters = lex_attr_getters
+            self._reset_lexeme_cache()
+
+    property lex_attr_data:
+        def __get__(self):
+            return self._lex_attr_data
+
+        def __set__(self, lex_attr_data):
+            self._lex_attr_data = lex_attr_data
+            self._reset_lexeme_cache()

     def to_disk(self, path, *, exclude=tuple()):
         """Save the current state to a directory.

@@ -459,6 +465,8 @@ cdef class Vocab:
             self.vectors.to_disk(path, exclude=["strings"])
         if "lookups" not in exclude:
             self.lookups.to_disk(path)
+        if "lex_attr_data" not in exclude:
+            srsly.write_msgpack(path / "lex_attr_data", self.lex_attr_data)

     def from_disk(self, path, *, exclude=tuple()):
         """Loads state from a directory. Modifies the object in place and

@@ -471,7 +479,6 @@ cdef class Vocab:
         DOCS: https://spacy.io/api/vocab#to_disk
         """
         path = util.ensure_path(path)
-        getters = ["strings", "vectors"]
         if "strings" not in exclude:
             self.strings.from_disk(path / "strings.json")  # TODO: add exclude?
         if "vectors" not in exclude:

@@ -479,12 +486,9 @@ cdef class Vocab:
                 self.vectors.from_disk(path, exclude=["strings"])
         if "lookups" not in exclude:
             self.lookups.from_disk(path)
-            if "lexeme_norm" in self.lookups:
-                self.lex_attr_getters[NORM] = util.add_lookups(
-                    self.lex_attr_getters.get(NORM, LEX_ATTRS[NORM]), self.lookups.get_table("lexeme_norm")
-                )
-        self.length = 0
-        self._by_orth = PreshMap()
+        if "lex_attr_data" not in exclude:
+            self.lex_attr_data = srsly.read_msgpack(path / "lex_attr_data")
+        self._reset_lexeme_cache()
         return self

     def to_bytes(self, *, exclude=tuple()):

@@ -505,6 +509,7 @@ cdef class Vocab:
             "strings": lambda: self.strings.to_bytes(),
             "vectors": deserialize_vectors,
             "lookups": lambda: self.lookups.to_bytes(),
+            "lex_attr_data": lambda: srsly.msgpack_dumps(self.lex_attr_data)
         }
         return util.to_bytes(getters, exclude)

@@ -523,24 +528,27 @@ cdef class Vocab:
             else:
                 return self.vectors.from_bytes(b, exclude=["strings"])

+        def serialize_lex_attr_data(b):
+            self.lex_attr_data = srsly.msgpack_loads(b)
+
         setters = {
             "strings": lambda b: self.strings.from_bytes(b),
             "vectors": lambda b: serialize_vectors(b),
             "lookups": lambda b: self.lookups.from_bytes(b),
+            "lex_attr_data": lambda b: serialize_lex_attr_data(b),
         }
         util.from_bytes(bytes_data, setters, exclude)
-        if "lexeme_norm" in self.lookups:
-            self.lex_attr_getters[NORM] = util.add_lookups(
-                self.lex_attr_getters.get(NORM, LEX_ATTRS[NORM]), self.lookups.get_table("lexeme_norm")
-            )
-        self.length = 0
-        self._by_orth = PreshMap()
+        self._reset_lexeme_cache()
         return self

     def _reset_cache(self, keys, strings):
         # I'm not sure this made sense. Disable it for now.
         raise NotImplementedError

+    def _reset_lexeme_cache(self):
+        self.length = 0
+        self._by_orth = PreshMap()
+

 def pickle_vocab(vocab):
     sstore = vocab.strings

@@ -565,3 +573,12 @@ def unpickle_vocab(sstore, vectors, morphology, lex_attr_getters, lookups, get_n


 copy_reg.pickle(Vocab, pickle_vocab, unpickle_vocab)
+
+
+def _get_lex_attr_value(vocab, func, string):
+    if "vocab" in inspect.signature(func).parameters:
+        value = func(vocab, string)
+    else:
+        # TODO: add deprecation warning
+        value = func(string)
+    return value

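`_get_lex_attr_value` keeps old one-argument getters working by checking whether the function declares a parameter named `vocab`. A standalone illustration of that dispatch (dummy getters, with `None` standing in for a real vocab):

import inspect


def _get_lex_attr_value(vocab, func, string):
    if "vocab" in inspect.signature(func).parameters:
        return func(vocab, string)
    # Legacy one-argument getters still work (deprecation warning TODO).
    return func(string)


def legacy_upper(string):          # old-style getter
    return string.upper()


def new_upper(vocab, string):      # new-style getter; its first parameter must be named "vocab"
    return string.upper()


assert _get_lex_attr_value(None, legacy_upper, "spacy") == "SPACY"
assert _get_lex_attr_value(None, new_upper, "spacy") == "SPACY"
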