Implement Vocab.lex_attr_data as a dataclass

Adriane Boyd 2023-03-30 12:37:39 +02:00
parent d3f7dcb3e3
commit 66afd23175
5 changed files with 53 additions and 13 deletions

spacy/lang/lex_attrs.py

@@ -1,4 +1,5 @@
from typing import Set
+from dataclasses import dataclass, field
import unicodedata
import re
@@ -6,6 +7,32 @@ from .. import attrs
from .tokenizer_exceptions import URL_MATCH
+@dataclass
+class LexAttrData:
+    lower: dict = field(default_factory=dict)
+    norm: dict = field(default_factory=dict)
+    prefix: dict = field(default_factory=dict)
+    suffix: dict = field(default_factory=dict)
+    is_alpha: dict = field(default_factory=dict)
+    is_digit: dict = field(default_factory=dict)
+    is_lower: dict = field(default_factory=dict)
+    is_space: dict = field(default_factory=dict)
+    is_title: dict = field(default_factory=dict)
+    is_upper: dict = field(default_factory=dict)
+    is_stop: dict = field(default_factory=dict)
+    like_email: dict = field(default_factory=dict)
+    like_num: dict = field(default_factory=dict)
+    is_punct: dict = field(default_factory=dict)
+    is_ascii: dict = field(default_factory=dict)
+    shape: dict = field(default_factory=dict)
+    is_bracket: dict = field(default_factory=dict)
+    is_quote: dict = field(default_factory=dict)
+    is_left_punct: dict = field(default_factory=dict)
+    is_right_punct: dict = field(default_factory=dict)
+    is_currency: dict = field(default_factory=dict)
+    like_url: dict = field(default_factory=dict)
_like_email = re.compile(r"([a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+)").match
_tlds = set(
"com|org|edu|gov|net|mil|aero|asia|biz|cat|coop|info|int|jobs|mobi|museum|"
@@ -179,8 +206,8 @@ def is_upper(vocab, string: str) -> bool:
def is_stop(vocab, string: str) -> bool:
-    stops = vocab.lex_attr_data.get("stops", [])
-    return string.lower() in stops
+    stop_words = vocab.lex_attr_data.is_stop.get("stop_words", set())
+    return string.lower() in stop_words
def norm(vocab, string: str) -> str:
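A minimal sketch (reduced to a single made-up field, MiniLexAttrData, rather than the full class above) of why the fields use field(default_factory=dict): mutable defaults have to be created per instance, so every LexAttrData starts with its own empty dicts, and getters such as is_stop now read named attributes instead of string keys.

from dataclasses import dataclass, field

@dataclass
class MiniLexAttrData:                    # cut-down stand-in for LexAttrData
    is_stop: dict = field(default_factory=dict)

a = MiniLexAttrData()
a.is_stop["stop_words"] = ["a", "an", "the"]
assert MiniLexAttrData().is_stop == {}    # each instance gets a fresh dict

# The updated is_stop getter reads the same structure by attribute access:
assert "the" in a.is_stop.get("stop_words", set())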

spacy/language.py

@@ -28,6 +28,7 @@ from .scorer import Scorer
from .util import registry, SimpleFrozenList, _pipe, raise_error, _DEFAULT_EMPTY_PIPES
from .util import SimpleFrozenDict, combine_score_weights, CONFIG_SECTION_ORDER
from .util import warn_if_jupyter_cupy
+from .lang.lex_attrs import LexAttrData
from .lang.tokenizer_exceptions import URL_MATCH, BASE_EXCEPTIONS
from .lang.punctuation import TOKENIZER_PREFIXES, TOKENIZER_SUFFIXES
from .lang.punctuation import TOKENIZER_INFIXES
@@ -74,7 +75,7 @@ class BaseDefaults:
    url_match: Optional[Callable] = URL_MATCH
    syntax_iterators: Dict[str, Callable] = {}
    lex_attr_getters: Dict[int, Callable[[str], Any]] = {}
-    lex_attr_data: Dict[str, Any] = {}
+    lex_attr_data: LexAttrData = LexAttrData()
    stop_words: Set[str] = set()
    writing_system = {"direction": "ltr", "has_case": True, "has_letters": True}
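For context, a rough sketch (not actual spaCy code; DummyDefaults is a made-up stand-in for a language's Defaults) of how the typed default interacts with the stop-word injection shown in the vocab.pyx hunk further down: create_vocab copies the language's stop words into lex_attr_data.is_stop, where the is_stop getter looks them up.

from spacy.lang.lex_attrs import LexAttrData

class DummyDefaults:                      # stand-in for a language's Defaults
    stop_words = {"foo", "bar"}
    lex_attr_data = LexAttrData()

# Mirrors the create_vocab logic in the vocab.pyx hunk below:
if len(DummyDefaults.stop_words) > 0:
    DummyDefaults.lex_attr_data.is_stop["stop_words"] = list(DummyDefaults.stop_words)

assert set(DummyDefaults.lex_attr_data.is_stop["stop_words"]) == {"foo", "bar"}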

spacy/tests/serialize/test_serialize_vocab_strings.py

@@ -11,7 +11,7 @@ from spacy.tokens import Doc
from spacy.util import ensure_path, load_model
from spacy.vectors import Vectors
from spacy.vocab import Vocab
-from spacy.lang.lex_attrs import norm
+from spacy.lang.lex_attrs import norm, LexAttrData
from ..util import make_tempdir
@@ -142,7 +142,12 @@ def test_deserialize_vocab_seen_entries(strings, lex_attr):
@pytest.mark.parametrize("strings,lex_attr", test_strings_attrs)
def test_serialize_vocab_lex_attrs_disk(strings, lex_attr):
-    vocab1 = Vocab(strings=strings, lex_attr_getters={NORM: norm})
+    stop_words_data = {"stop_words": ["a", "b", "c"]}
+    vocab1 = Vocab(
+        strings=strings,
+        lex_attr_getters={NORM: norm},
+        lex_attr_data=LexAttrData(is_stop=stop_words_data),
+    )
    vocab2 = Vocab(lex_attr_getters={NORM: norm})
    s = next(iter(vocab1.strings))
    vocab1[s].norm_ = lex_attr
@@ -153,6 +158,8 @@ def test_serialize_vocab_lex_attrs_disk(strings, lex_attr):
        vocab1.to_disk(file_path)
        vocab2 = vocab2.from_disk(file_path)
    assert vocab2[s].norm_ == lex_attr
+    assert vocab1.lex_attr_data == vocab2.lex_attr_data
+    assert vocab2.lex_attr_data.is_stop == stop_words_data
@pytest.mark.parametrize("strings1,strings2", test_strings)
@@ -197,7 +204,7 @@ def test_pickle_vocab(strings, lex_attr):
    ops = get_current_ops()
    vectors = Vectors(data=ops.xp.zeros((10, 10)), mode="floret", hash_count=1)
    vocab.vectors = vectors
-    vocab.lex_attr_data = {"a": 1}
+    vocab.lex_attr_data.is_stop = {"stop_words": [1, 2, 3]}
    vocab[strings[0]].norm_ = lex_attr
    vocab_pickled = pickle.dumps(vocab)
    vocab_unpickled = pickle.loads(vocab_pickled)
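These assertions lean on behavior dataclasses provide for free: a generated __eq__ that compares field values, and ordinary pickle support. A small sketch of the idea, assuming the LexAttrData introduced in this commit is importable:

import pickle
from spacy.lang.lex_attrs import LexAttrData

a = LexAttrData(is_stop={"stop_words": ["a", "b", "c"]})
b = pickle.loads(pickle.dumps(a))          # round-trip, as in test_pickle_vocab
assert a == b                              # field-by-field dataclass equality
assert b.is_stop == {"stop_words": ["a", "b", "c"]}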

spacy/vocab.pyi

@@ -3,6 +3,7 @@ from typing import Any, Iterable
from thinc.types import Floats1d, FloatsXd
from . import Language
from .strings import StringStore
+from .lang.lex_attrs import LexAttrData
from .lexeme import Lexeme
from .lookups import Lookups
from .morphology import Morphology
@@ -31,6 +32,7 @@ class Vocab:
        oov_prob: float = ...,
        writing_system: Dict[str, Any] = ...,
        get_noun_chunks: Optional[Callable[[Union[Doc, Span]], Iterator[Span]]] = ...,
+        lex_attr_data=LexAttrData,
    ) -> None: ...
    @property
    def lang(self) -> str: ...

spacy/vocab.pyx

@@ -1,4 +1,6 @@
# cython: profile=True
+from dataclasses import asdict
from libc.string cimport memcpy
import numpy
@@ -21,14 +23,15 @@ from .util import registry
from .lookups import Lookups
from . import util
from .lang.norm_exceptions import BASE_NORMS
-from .lang.lex_attrs import LEX_ATTRS
+from .lang.lex_attrs import LEX_ATTRS, LexAttrData
def create_vocab(lang, defaults):
    # If the spacy-lookups-data package is installed, we pre-populate the lookups
    # with lexeme data, if available
    lex_attrs = {**LEX_ATTRS, **defaults.lex_attr_getters}
-    defaults.lex_attr_data["stops"] = list(defaults.stop_words)
+    if len(defaults.stop_words) > 0:
+        defaults.lex_attr_data.is_stop["stop_words"] = list(defaults.stop_words)
    lookups = Lookups()
    lookups.add_table("lexeme_norm", BASE_NORMS)
    return Vocab(
@@ -63,7 +66,7 @@ cdef class Vocab:
            A function that yields base noun phrases used for Doc.noun_chunks.
        """
        lex_attr_getters = lex_attr_getters if lex_attr_getters is not None else {}
-        lex_attr_data = lex_attr_data if lex_attr_data is not None else {}
+        lex_attr_data = lex_attr_data if lex_attr_data is not None else LexAttrData()
        if lookups in (None, True, False):
            lookups = Lookups()
        self.cfg = {'oov_prob': oov_prob}
@@ -467,7 +470,7 @@ cdef class Vocab:
        if "lookups" not in exclude:
            self.lookups.to_disk(path)
        if "lex_attr_data" not in exclude:
-            srsly.write_msgpack(path / "lex_attr_data", self.lex_attr_data)
+            srsly.write_msgpack(path / "lex_attr_data", asdict(self.lex_attr_data))
    def from_disk(self, path, *, exclude=tuple()):
        """Loads state from a directory. Modifies the object in place and
@@ -488,7 +491,7 @@ cdef class Vocab:
        if "lookups" not in exclude:
            self.lookups.from_disk(path)
        if "lex_attr_data" not in exclude:
-            self.lex_attr_data = srsly.read_msgpack(path / "lex_attr_data")
+            self.lex_attr_data = LexAttrData(**srsly.read_msgpack(path / "lex_attr_data"))
        self._reset_lexeme_cache()
        return self
@@ -510,7 +513,7 @@ cdef class Vocab:
            "strings": lambda: self.strings.to_bytes(),
            "vectors": deserialize_vectors,
            "lookups": lambda: self.lookups.to_bytes(),
-            "lex_attr_data": lambda: srsly.msgpack_dumps(self.lex_attr_data)
+            "lex_attr_data": lambda: srsly.msgpack_dumps(asdict(self.lex_attr_data))
        }
        return util.to_bytes(getters, exclude)
@@ -530,7 +533,7 @@ cdef class Vocab:
            return self.vectors.from_bytes(b, exclude=["strings"])
        def serialize_lex_attr_data(b):
-            self.lex_attr_data = srsly.msgpack_loads(b)
+            self.lex_attr_data = LexAttrData(**srsly.msgpack_loads(b))
        setters = {
            "strings": lambda b: self.strings.from_bytes(b),