Mirror of https://github.com/explosion/spaCy.git
Implement Vocab.lex_attr_data as a dataclass
parent d3f7dcb3e3
commit 66afd23175
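This commit replaces the free-form dict previously stored as Vocab.lex_attr_data with a LexAttrData dataclass that holds one dict per lexical attribute, and threads the new type through the language defaults, serialization, and the type stubs. As a rough before/after sketch of the access pattern (a self-contained stand-in that does not import spaCy, since the attribute only exists on a build that includes this commit; the variable names are illustrative):

from dataclasses import dataclass, field


@dataclass
class LexAttrData:  # trimmed stand-in for the class added in spacy/lang/lex_attrs.py
    is_stop: dict = field(default_factory=dict)


# Before: an untyped dict with ad-hoc keys.
old_data = {"stops": ["a", "an", "the"]}
assert "the" in old_data.get("stops", [])

# After: one dict per attribute, addressed as a field (see is_stop() in the diff below).
new_data = LexAttrData(is_stop={"stop_words": ["a", "an", "the"]})
assert "the" in new_data.is_stop.get("stop_words", set())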
--- a/spacy/lang/lex_attrs.py
+++ b/spacy/lang/lex_attrs.py
@@ -1,4 +1,5 @@
 from typing import Set
+from dataclasses import dataclass, field
 import unicodedata
 import re
 
@@ -6,6 +7,32 @@ from .. import attrs
 from .tokenizer_exceptions import URL_MATCH
 
 
+@dataclass
+class LexAttrData:
+    lower: dict = field(default_factory=dict)
+    norm: dict = field(default_factory=dict)
+    prefix: dict = field(default_factory=dict)
+    suffix: dict = field(default_factory=dict)
+    is_alpha: dict = field(default_factory=dict)
+    is_digit: dict = field(default_factory=dict)
+    is_lower: dict = field(default_factory=dict)
+    is_space: dict = field(default_factory=dict)
+    is_title: dict = field(default_factory=dict)
+    is_upper: dict = field(default_factory=dict)
+    is_stop: dict = field(default_factory=dict)
+    like_email: dict = field(default_factory=dict)
+    like_num: dict = field(default_factory=dict)
+    is_punct: dict = field(default_factory=dict)
+    is_ascii: dict = field(default_factory=dict)
+    shape: dict = field(default_factory=dict)
+    is_bracket: dict = field(default_factory=dict)
+    is_quote: dict = field(default_factory=dict)
+    is_left_punct: dict = field(default_factory=dict)
+    is_right_punct: dict = field(default_factory=dict)
+    is_currency: dict = field(default_factory=dict)
+    like_url: dict = field(default_factory=dict)
+
+
 _like_email = re.compile(r"([a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+)").match
 _tlds = set(
     "com|org|edu|gov|net|mil|aero|asia|biz|cat|coop|info|int|jobs|mobi|museum|"
@@ -179,8 +206,8 @@ def is_upper(vocab, string: str) -> bool:
 
 
 def is_stop(vocab, string: str) -> bool:
-    stops = vocab.lex_attr_data.get("stops", [])
-    return string.lower() in stops
+    stop_words = vocab.lex_attr_data.is_stop.get("stop_words", set())
+    return string.lower() in stop_words
 
 
 def norm(vocab, string: str) -> str:
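A side note on the field(default_factory=dict) pattern used for every attribute above: a bare dict default would be rejected by the dataclass machinery (mutable default), and default_factory gives each LexAttrData instance its own dicts. A small sketch with a two-field stand-in (assumed shape, not imported from spaCy):

from dataclasses import dataclass, field


@dataclass
class LexAttrData:  # two-field stand-in for the class defined above
    norm: dict = field(default_factory=dict)
    is_stop: dict = field(default_factory=dict)


a = LexAttrData()
b = LexAttrData()
a.is_stop["stop_words"] = ["a", "b"]
assert b.is_stop == {}  # each instance gets its own dict via default_factory

# A bare mutable default would raise at class-creation time:
# @dataclass
# class Broken:
#     is_stop: dict = {}  # ValueError: mutable default for field is_stop, use default_factory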
--- a/spacy/language.py
+++ b/spacy/language.py
@@ -28,6 +28,7 @@ from .scorer import Scorer
 from .util import registry, SimpleFrozenList, _pipe, raise_error, _DEFAULT_EMPTY_PIPES
 from .util import SimpleFrozenDict, combine_score_weights, CONFIG_SECTION_ORDER
 from .util import warn_if_jupyter_cupy
+from .lang.lex_attrs import LexAttrData
 from .lang.tokenizer_exceptions import URL_MATCH, BASE_EXCEPTIONS
 from .lang.punctuation import TOKENIZER_PREFIXES, TOKENIZER_SUFFIXES
 from .lang.punctuation import TOKENIZER_INFIXES
@@ -74,7 +75,7 @@ class BaseDefaults:
     url_match: Optional[Callable] = URL_MATCH
     syntax_iterators: Dict[str, Callable] = {}
     lex_attr_getters: Dict[int, Callable[[str], Any]] = {}
-    lex_attr_data: Dict[str, Any] = {}
+    lex_attr_data: LexAttrData = LexAttrData()
     stop_words: Set[str] = set()
     writing_system = {"direction": "ltr", "has_case": True, "has_letters": True}
 
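BaseDefaults now carries a LexAttrData instance instead of a dict, and create_vocab() in spacy/vocab.pyx (further below) copies a language's stop_words into its is_stop slot. A stand-alone sketch of that flow under assumed names (the Defaults stand-in and stop word list are illustrative only):

from dataclasses import dataclass, field
from typing import Set


@dataclass
class LexAttrData:  # stand-in for spacy.lang.lex_attrs.LexAttrData
    is_stop: dict = field(default_factory=dict)


class Defaults:  # stand-in for a language's Defaults subclass
    stop_words: Set[str] = {"a", "an", "the"}
    lex_attr_data: LexAttrData = LexAttrData()


# Mirrors the guarded copy added to create_vocab() in spacy/vocab.pyx.
if len(Defaults.stop_words) > 0:
    Defaults.lex_attr_data.is_stop["stop_words"] = list(Defaults.stop_words)

assert "the" in Defaults.lex_attr_data.is_stop["stop_words"]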
--- a/spacy/tests/serialize/test_serialize_vocab_strings.py
+++ b/spacy/tests/serialize/test_serialize_vocab_strings.py
@@ -11,7 +11,7 @@ from spacy.tokens import Doc
 from spacy.util import ensure_path, load_model
 from spacy.vectors import Vectors
 from spacy.vocab import Vocab
-from spacy.lang.lex_attrs import norm
+from spacy.lang.lex_attrs import norm, LexAttrData
 
 from ..util import make_tempdir
 
@@ -142,7 +142,12 @@ def test_deserialize_vocab_seen_entries(strings, lex_attr):
 
 @pytest.mark.parametrize("strings,lex_attr", test_strings_attrs)
 def test_serialize_vocab_lex_attrs_disk(strings, lex_attr):
-    vocab1 = Vocab(strings=strings, lex_attr_getters={NORM: norm})
+    stop_words_data = {"stop_words": ["a", "b", "c"]}
+    vocab1 = Vocab(
+        strings=strings,
+        lex_attr_getters={NORM: norm},
+        lex_attr_data=LexAttrData(is_stop=stop_words_data),
+    )
     vocab2 = Vocab(lex_attr_getters={NORM: norm})
     s = next(iter(vocab1.strings))
     vocab1[s].norm_ = lex_attr
@@ -153,6 +158,8 @@ def test_serialize_vocab_lex_attrs_disk(strings, lex_attr):
         vocab1.to_disk(file_path)
         vocab2 = vocab2.from_disk(file_path)
     assert vocab2[s].norm_ == lex_attr
+    assert vocab1.lex_attr_data == vocab2.lex_attr_data
+    assert vocab2.lex_attr_data.is_stop == stop_words_data
 
 
 @pytest.mark.parametrize("strings1,strings2", test_strings)
@@ -197,7 +204,7 @@ def test_pickle_vocab(strings, lex_attr):
     ops = get_current_ops()
     vectors = Vectors(data=ops.xp.zeros((10, 10)), mode="floret", hash_count=1)
     vocab.vectors = vectors
-    vocab.lex_attr_data = {"a": 1}
+    vocab.lex_attr_data.is_stop = {"stop_words": [1, 2, 3]}
     vocab[strings[0]].norm_ = lex_attr
     vocab_pickled = pickle.dumps(vocab)
     vocab_unpickled = pickle.loads(vocab_pickled)
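The pickle test only needed its fixture value updated; the Vocab pickling path itself appears untouched in this diff, presumably because a dataclass instance pickles as readily as the dict it replaces. A tiny illustration with a stand-in class (not the real spaCy class):

import pickle
from dataclasses import dataclass, field


@dataclass
class LexAttrData:  # stand-in; the real class lives in spacy.lang.lex_attrs
    is_stop: dict = field(default_factory=dict)


data = LexAttrData(is_stop={"stop_words": [1, 2, 3]})
assert pickle.loads(pickle.dumps(data)) == data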
--- a/spacy/vocab.pyi
+++ b/spacy/vocab.pyi
@@ -3,6 +3,7 @@ from typing import Any, Iterable
 from thinc.types import Floats1d, FloatsXd
 from . import Language
 from .strings import StringStore
+from .lang.lex_attrs import LexAttrData
 from .lexeme import Lexeme
 from .lookups import Lookups
 from .morphology import Morphology
@@ -31,6 +32,7 @@ class Vocab:
         oov_prob: float = ...,
         writing_system: Dict[str, Any] = ...,
         get_noun_chunks: Optional[Callable[[Union[Doc, Span]], Iterator[Span]]] = ...,
+        lex_attr_data=LexAttrData,
     ) -> None: ...
     @property
     def lang(self) -> str: ...
--- a/spacy/vocab.pyx
+++ b/spacy/vocab.pyx
@@ -1,4 +1,6 @@
 # cython: profile=True
+from dataclasses import asdict
+
 from libc.string cimport memcpy
 
 import numpy
@@ -21,14 +23,15 @@ from .util import registry
 from .lookups import Lookups
 from . import util
 from .lang.norm_exceptions import BASE_NORMS
-from .lang.lex_attrs import LEX_ATTRS
+from .lang.lex_attrs import LEX_ATTRS, LexAttrData
 
 
 def create_vocab(lang, defaults):
     # If the spacy-lookups-data package is installed, we pre-populate the lookups
     # with lexeme data, if available
     lex_attrs = {**LEX_ATTRS, **defaults.lex_attr_getters}
-    defaults.lex_attr_data["stops"] = list(defaults.stop_words)
+    if len(defaults.stop_words) > 0:
+        defaults.lex_attr_data.is_stop["stop_words"] = list(defaults.stop_words)
     lookups = Lookups()
     lookups.add_table("lexeme_norm", BASE_NORMS)
     return Vocab(
@@ -63,7 +66,7 @@ cdef class Vocab:
             A function that yields base noun phrases used for Doc.noun_chunks.
         """
         lex_attr_getters = lex_attr_getters if lex_attr_getters is not None else {}
-        lex_attr_data = lex_attr_data if lex_attr_data is not None else {}
+        lex_attr_data = lex_attr_data if lex_attr_data is not None else LexAttrData()
         if lookups in (None, True, False):
             lookups = Lookups()
         self.cfg = {'oov_prob': oov_prob}
@@ -467,7 +470,7 @@ cdef class Vocab:
         if "lookups" not in exclude:
             self.lookups.to_disk(path)
         if "lex_attr_data" not in exclude:
-            srsly.write_msgpack(path / "lex_attr_data", self.lex_attr_data)
+            srsly.write_msgpack(path / "lex_attr_data", asdict(self.lex_attr_data))
 
     def from_disk(self, path, *, exclude=tuple()):
         """Loads state from a directory. Modifies the object in place and
@@ -488,7 +491,7 @@ cdef class Vocab:
         if "lookups" not in exclude:
             self.lookups.from_disk(path)
         if "lex_attr_data" not in exclude:
-            self.lex_attr_data = srsly.read_msgpack(path / "lex_attr_data")
+            self.lex_attr_data = LexAttrData(**srsly.read_msgpack(path / "lex_attr_data"))
         self._reset_lexeme_cache()
         return self
 
@@ -510,7 +513,7 @@ cdef class Vocab:
             "strings": lambda: self.strings.to_bytes(),
             "vectors": deserialize_vectors,
             "lookups": lambda: self.lookups.to_bytes(),
-            "lex_attr_data": lambda: srsly.msgpack_dumps(self.lex_attr_data)
+            "lex_attr_data": lambda: srsly.msgpack_dumps(asdict(self.lex_attr_data))
         }
         return util.to_bytes(getters, exclude)
 
@@ -530,7 +533,7 @@ cdef class Vocab:
             return self.vectors.from_bytes(b, exclude=["strings"])
 
         def serialize_lex_attr_data(b):
-            self.lex_attr_data = srsly.msgpack_loads(b)
+            self.lex_attr_data = LexAttrData(**srsly.msgpack_loads(b))
 
         setters = {
             "strings": lambda b: self.strings.from_bytes(b),
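On the byte-level path, to_bytes()/from_bytes() now msgpack the asdict() form and rebuild the dataclass on load, mirroring the disk path. A self-contained sketch of that round trip (srsly.msgpack_dumps/msgpack_loads are srsly's public helpers; the LexAttrData copy here is trimmed, not imported from spaCy):

import srsly
from dataclasses import asdict, dataclass, field


@dataclass
class LexAttrData:  # trimmed copy of the dataclass from spacy/lang/lex_attrs.py
    is_stop: dict = field(default_factory=dict)


data = LexAttrData(is_stop={"stop_words": ["a", "b", "c"]})
payload = srsly.msgpack_dumps(asdict(data))        # as in Vocab.to_bytes()
restored = LexAttrData(**srsly.msgpack_loads(payload))  # as in Vocab.from_bytes()
assert restored == data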