spaCy (https://github.com/explosion/spaCy.git)
Commit 66afd23175 (parent d3f7dcb3e3): Implement Vocab.lex_attr_data as a dataclass
spacy/lang/lex_attrs.py

@@ -1,4 +1,5 @@
 from typing import Set
+from dataclasses import dataclass, field
 import unicodedata
 import re
 
@@ -6,6 +7,32 @@ from .. import attrs
 from .tokenizer_exceptions import URL_MATCH
 
 
+@dataclass
+class LexAttrData:
+    lower: dict = field(default_factory=dict)
+    norm: dict = field(default_factory=dict)
+    prefix: dict = field(default_factory=dict)
+    suffix: dict = field(default_factory=dict)
+    is_alpha: dict = field(default_factory=dict)
+    is_digit: dict = field(default_factory=dict)
+    is_lower: dict = field(default_factory=dict)
+    is_space: dict = field(default_factory=dict)
+    is_title: dict = field(default_factory=dict)
+    is_upper: dict = field(default_factory=dict)
+    is_stop: dict = field(default_factory=dict)
+    like_email: dict = field(default_factory=dict)
+    like_num: dict = field(default_factory=dict)
+    is_punct: dict = field(default_factory=dict)
+    is_ascii: dict = field(default_factory=dict)
+    shape: dict = field(default_factory=dict)
+    is_bracket: dict = field(default_factory=dict)
+    is_quote: dict = field(default_factory=dict)
+    is_left_punct: dict = field(default_factory=dict)
+    is_right_punct: dict = field(default_factory=dict)
+    is_currency: dict = field(default_factory=dict)
+    like_url: dict = field(default_factory=dict)
+
+
 _like_email = re.compile(r"([a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+)").match
 _tlds = set(
     "com|org|edu|gov|net|mil|aero|asia|biz|cat|coop|info|int|jobs|mobi|museum|"
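A quick sketch of how the new container behaves (illustration only, not part of the commit): each field gets its own empty dict via default_factory, so per-language data lives under a named attribute instead of an arbitrary key in a shared dict.

    from spacy.lang.lex_attrs import LexAttrData

    data = LexAttrData(is_stop={"stop_words": ["a", "the", "of"]})
    assert data.is_stop["stop_words"] == ["a", "the", "of"]
    # Fields that were not passed in stay independent empty dicts:
    assert data.like_num == {}
    assert data.is_stop is not LexAttrData().is_stop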
@@ -179,8 +206,8 @@ def is_upper(vocab, string: str) -> bool:
 
 
 def is_stop(vocab, string: str) -> bool:
-    stops = vocab.lex_attr_data.get("stops", [])
-    return string.lower() in stops
+    stop_words = vocab.lex_attr_data.is_stop.get("stop_words", set())
+    return string.lower() in stop_words
 
 
 def norm(vocab, string: str) -> str:
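Sketch of the resulting lookup (illustration only; assumes Vocab accepts the lex_attr_data keyword added in the vocab.pyx hunks below): the getter now reads the "stop_words" key from the is_stop field rather than a top-level "stops" key.

    from spacy.vocab import Vocab
    from spacy.lang.lex_attrs import LexAttrData, is_stop

    vocab = Vocab(lex_attr_data=LexAttrData(is_stop={"stop_words": ["the"]}))
    assert is_stop(vocab, "The") is True   # lookup lowercases the string first
    assert is_stop(vocab, "cat") is False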
spacy/language.py

@@ -28,6 +28,7 @@ from .scorer import Scorer
 from .util import registry, SimpleFrozenList, _pipe, raise_error, _DEFAULT_EMPTY_PIPES
 from .util import SimpleFrozenDict, combine_score_weights, CONFIG_SECTION_ORDER
 from .util import warn_if_jupyter_cupy
+from .lang.lex_attrs import LexAttrData
 from .lang.tokenizer_exceptions import URL_MATCH, BASE_EXCEPTIONS
 from .lang.punctuation import TOKENIZER_PREFIXES, TOKENIZER_SUFFIXES
 from .lang.punctuation import TOKENIZER_INFIXES

@@ -74,7 +75,7 @@ class BaseDefaults:
     url_match: Optional[Callable] = URL_MATCH
     syntax_iterators: Dict[str, Callable] = {}
     lex_attr_getters: Dict[int, Callable[[str], Any]] = {}
-    lex_attr_data: Dict[str, Any] = {}
+    lex_attr_data: LexAttrData = LexAttrData()
     stop_words: Set[str] = set()
     writing_system = {"direction": "ltr", "has_case": True, "has_letters": True}
 
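A hypothetical language Defaults subclass (illustration only; MyLangDefaults is invented here) showing how the typed field can be populated; create_vocab, in the vocab.pyx hunk further down, also copies stop_words into is_stop at vocab-creation time.

    from spacy.language import BaseDefaults
    from spacy.lang.lex_attrs import LexAttrData

    class MyLangDefaults(BaseDefaults):
        stop_words = {"foo", "bar"}
        # Explicitly seed the dataclass instead of writing to an untyped dict:
        lex_attr_data = LexAttrData(is_stop={"stop_words": ["foo", "bar"]})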
spacy/tests/serialization/test_serialize_vocab_strings.py

@@ -11,7 +11,7 @@ from spacy.tokens import Doc
 from spacy.util import ensure_path, load_model
 from spacy.vectors import Vectors
 from spacy.vocab import Vocab
-from spacy.lang.lex_attrs import norm
+from spacy.lang.lex_attrs import norm, LexAttrData
 
 from ..util import make_tempdir
 
@@ -142,7 +142,12 @@ def test_deserialize_vocab_seen_entries(strings, lex_attr):
 
 @pytest.mark.parametrize("strings,lex_attr", test_strings_attrs)
 def test_serialize_vocab_lex_attrs_disk(strings, lex_attr):
-    vocab1 = Vocab(strings=strings, lex_attr_getters={NORM: norm})
+    stop_words_data = {"stop_words": ["a", "b", "c"]}
+    vocab1 = Vocab(
+        strings=strings,
+        lex_attr_getters={NORM: norm},
+        lex_attr_data=LexAttrData(is_stop=stop_words_data),
+    )
     vocab2 = Vocab(lex_attr_getters={NORM: norm})
     s = next(iter(vocab1.strings))
     vocab1[s].norm_ = lex_attr
@@ -153,6 +158,8 @@ def test_serialize_vocab_lex_attrs_disk(strings, lex_attr):
     vocab1.to_disk(file_path)
     vocab2 = vocab2.from_disk(file_path)
     assert vocab2[s].norm_ == lex_attr
+    assert vocab1.lex_attr_data == vocab2.lex_attr_data
+    assert vocab2.lex_attr_data.is_stop == stop_words_data
 
 
 @pytest.mark.parametrize("strings1,strings2", test_strings)
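Why the new equality assertion is meaningful (illustration only, not part of the commit): @dataclass generates an __eq__ that compares field values, so a deserialized LexAttrData compares equal to the original by value rather than by identity.

    from spacy.lang.lex_attrs import LexAttrData

    a = LexAttrData(is_stop={"stop_words": ["a"]})
    b = LexAttrData(is_stop={"stop_words": ["a"]})
    assert a == b          # field-by-field comparison
    assert a is not b      # distinct instances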
@@ -197,7 +204,7 @@ def test_pickle_vocab(strings, lex_attr):
     ops = get_current_ops()
     vectors = Vectors(data=ops.xp.zeros((10, 10)), mode="floret", hash_count=1)
     vocab.vectors = vectors
-    vocab.lex_attr_data = {"a": 1}
+    vocab.lex_attr_data.is_stop = {"stop_words": [1, 2, 3]}
     vocab[strings[0]].norm_ = lex_attr
     vocab_pickled = pickle.dumps(vocab)
     vocab_unpickled = pickle.loads(vocab_pickled)
spacy/vocab.pyi

@@ -3,6 +3,7 @@ from typing import Any, Iterable
 from thinc.types import Floats1d, FloatsXd
 from . import Language
 from .strings import StringStore
+from .lang.lex_attrs import LexAttrData
 from .lexeme import Lexeme
 from .lookups import Lookups
 from .morphology import Morphology
@@ -31,6 +32,7 @@ class Vocab:
         oov_prob: float = ...,
         writing_system: Dict[str, Any] = ...,
         get_noun_chunks: Optional[Callable[[Union[Doc, Span]], Iterator[Span]]] = ...,
+        lex_attr_data: LexAttrData = ...,
     ) -> None: ...
     @property
     def lang(self) -> str: ...
spacy/vocab.pyx

@@ -1,4 +1,6 @@
 # cython: profile=True
+from dataclasses import asdict
+
 from libc.string cimport memcpy
 
 import numpy
@@ -21,14 +23,15 @@ from .util import registry
 from .lookups import Lookups
 from . import util
 from .lang.norm_exceptions import BASE_NORMS
-from .lang.lex_attrs import LEX_ATTRS
+from .lang.lex_attrs import LEX_ATTRS, LexAttrData
 
 
 def create_vocab(lang, defaults):
     # If the spacy-lookups-data package is installed, we pre-populate the lookups
     # with lexeme data, if available
     lex_attrs = {**LEX_ATTRS, **defaults.lex_attr_getters}
-    defaults.lex_attr_data["stops"] = list(defaults.stop_words)
+    if len(defaults.stop_words) > 0:
+        defaults.lex_attr_data.is_stop["stop_words"] = list(defaults.stop_words)
     lookups = Lookups()
     lookups.add_table("lexeme_norm", BASE_NORMS)
     return Vocab(
@@ -63,7 +66,7 @@ cdef class Vocab:
             A function that yields base noun phrases used for Doc.noun_chunks.
         """
         lex_attr_getters = lex_attr_getters if lex_attr_getters is not None else {}
-        lex_attr_data = lex_attr_data if lex_attr_data is not None else {}
+        lex_attr_data = lex_attr_data if lex_attr_data is not None else LexAttrData()
        if lookups in (None, True, False):
             lookups = Lookups()
         self.cfg = {'oov_prob': oov_prob}
@@ -467,7 +470,7 @@ cdef class Vocab:
         if "lookups" not in exclude:
             self.lookups.to_disk(path)
         if "lex_attr_data" not in exclude:
-            srsly.write_msgpack(path / "lex_attr_data", self.lex_attr_data)
+            srsly.write_msgpack(path / "lex_attr_data", asdict(self.lex_attr_data))
 
     def from_disk(self, path, *, exclude=tuple()):
         """Loads state from a directory. Modifies the object in place and
@@ -488,7 +491,7 @@ cdef class Vocab:
         if "lookups" not in exclude:
             self.lookups.from_disk(path)
         if "lex_attr_data" not in exclude:
-            self.lex_attr_data = srsly.read_msgpack(path / "lex_attr_data")
+            self.lex_attr_data = LexAttrData(**srsly.read_msgpack(path / "lex_attr_data"))
         self._reset_lexeme_cache()
         return self
 
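The (de)serialization pattern in isolation (illustration only, not part of the commit): msgpack cannot encode a dataclass directly, so the commit converts to a plain dict with asdict() on write and rebuilds the dataclass via keyword expansion on read.

    import srsly
    from dataclasses import asdict
    from spacy.lang.lex_attrs import LexAttrData

    data = LexAttrData(is_stop={"stop_words": ["a", "b"]})
    payload = srsly.msgpack_dumps(asdict(data))              # dataclass -> dict -> bytes
    restored = LexAttrData(**srsly.msgpack_loads(payload))   # bytes -> dict -> dataclass
    assert restored == data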
@@ -510,7 +513,7 @@ cdef class Vocab:
             "strings": lambda: self.strings.to_bytes(),
             "vectors": deserialize_vectors,
             "lookups": lambda: self.lookups.to_bytes(),
-            "lex_attr_data": lambda: srsly.msgpack_dumps(self.lex_attr_data)
+            "lex_attr_data": lambda: srsly.msgpack_dumps(asdict(self.lex_attr_data))
         }
         return util.to_bytes(getters, exclude)
 
@@ -530,7 +533,7 @@ cdef class Vocab:
             return self.vectors.from_bytes(b, exclude=["strings"])
 
         def serialize_lex_attr_data(b):
-            self.lex_attr_data = srsly.msgpack_loads(b)
+            self.lex_attr_data = LexAttrData(**srsly.msgpack_loads(b))
 
         setters = {
             "strings": lambda b: self.strings.from_bytes(b),