Sketch of lex_attr_getters taking the vocab

Adriane Boyd 2023-02-06 17:09:52 +01:00
parent 6920fb7baf
commit b9a91120b2
11 changed files with 112 additions and 125 deletions

View File

@@ -23,21 +23,21 @@ _tlds = set(
)
def is_punct(text: str) -> bool:
def is_punct(vocab, text: str) -> bool:
for char in text:
if not unicodedata.category(char).startswith("P"):
return False
return True
def is_ascii(text: str) -> bool:
def is_ascii(vocab, text: str) -> bool:
for char in text:
if ord(char) >= 128:
return False
return True
def like_num(text: str) -> bool:
def like_num(vocab, text: str) -> bool:
if text.startswith(("+", "-", "±", "~")):
text = text[1:]
# can be overwritten by lang with list of number words
@@ -51,31 +51,31 @@ def like_num(text: str) -> bool:
return False
def is_bracket(text: str) -> bool:
def is_bracket(vocab, text: str) -> bool:
brackets = ("(", ")", "[", "]", "{", "}", "<", ">")
return text in brackets
def is_quote(text: str) -> bool:
def is_quote(vocab, text: str) -> bool:
# fmt: off
quotes = ('"', "'", "`", "«", "»", "", "", "", "", "", "", "", "", "", "", "", "", "''", "``")
# fmt: on
return text in quotes
def is_left_punct(text: str) -> bool:
def is_left_punct(vocab, text: str) -> bool:
# fmt: off
left_punct = ("(", "[", "{", "<", '"', "'", "«", "", "", "", "", "", "", "", "", "``")
# fmt: on
return text in left_punct
def is_right_punct(text: str) -> bool:
def is_right_punct(vocab, text: str) -> bool:
right_punct = (")", "]", "}", ">", '"', "'", "»", "", "", "", "", "''")
return text in right_punct
def is_currency(text: str) -> bool:
def is_currency(vocab, text: str) -> bool:
# can be overwritten by lang with list of currency words, e.g. dollar, euro
for char in text:
if unicodedata.category(char) != "Sc":
@@ -83,11 +83,11 @@ def is_currency(text: str) -> bool:
return True
def like_email(text: str) -> bool:
def like_email(vocab, text: str) -> bool:
return bool(_like_email(text))
def like_url(text: str) -> bool:
def like_url(vocab, text: str) -> bool:
# We're looking for things that function in text like URLs. So, valid URL
# or not, anything they say http:// is going to be good.
if text.startswith("http://") or text.startswith("https://"):
@@ -115,7 +115,7 @@ def like_url(text: str) -> bool:
return False
def word_shape(text: str) -> str:
def word_shape(vocab, text: str) -> str:
if len(text) >= 100:
return "LONG"
shape = []
@@ -142,55 +142,54 @@ def word_shape(text: str) -> str:
return "".join(shape)
def lower(string: str) -> str:
def lower(vocab, string: str) -> str:
return string.lower()
def prefix(string: str) -> str:
def prefix(vocab, string: str) -> str:
return string[0]
def suffix(string: str) -> str:
def suffix(vocab, string: str) -> str:
return string[-3:]
def is_alpha(string: str) -> bool:
def is_alpha(vocab, string: str) -> bool:
return string.isalpha()
def is_digit(string: str) -> bool:
def is_digit(vocab, string: str) -> bool:
return string.isdigit()
def is_lower(string: str) -> bool:
def is_lower(vocab, string: str) -> bool:
return string.islower()
def is_space(string: str) -> bool:
def is_space(vocab, string: str) -> bool:
return string.isspace()
def is_title(string: str) -> bool:
def is_title(vocab, string: str) -> bool:
return string.istitle()
def is_upper(string: str) -> bool:
def is_upper(vocab, string: str) -> bool:
return string.isupper()
def is_stop(string: str, stops: Set[str] = set()) -> bool:
def is_stop(vocab, string: str) -> bool:
stops = vocab.lex_attr_data.get("stops", [])
return string.lower() in stops
def get_lang(text: str, lang: str = "") -> str:
# This function is partially applied so lang code can be passed in
# automatically while still allowing pickling
return lang
def norm(vocab, string: str) -> str:
return vocab.lookups.get_table("lexeme_norm", {}).get(string, string.lower())
LEX_ATTRS = {
attrs.LOWER: lower,
attrs.NORM: lower,
attrs.NORM: norm,
attrs.PREFIX: prefix,
attrs.SUFFIX: suffix,
attrs.IS_ALPHA: is_alpha,
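Every getter in this file now takes the vocab as its first argument, so a language can read shared data through vocab.lex_attr_data or vocab.lookups instead of having it bound in via functools.partial. A minimal sketch of a custom getter under the new signature (the "num_words" key is a hypothetical example, not data this commit defines):

    from spacy.attrs import LIKE_NUM

    def like_num(vocab, text: str) -> bool:
        # plain digit strings behave exactly as before
        if text.isdigit():
            return True
        # language-specific number words can now be looked up through the vocab
        return text.lower() in vocab.lex_attr_data.get("num_words", [])

    LEX_ATTRS = {LIKE_NUM: like_num}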

View File

@@ -74,6 +74,7 @@ class BaseDefaults:
url_match: Optional[Callable] = URL_MATCH
syntax_iterators: Dict[str, Callable] = {}
lex_attr_getters: Dict[int, Callable[[str], Any]] = {}
lex_attr_data: Dict[str, Any] = {}
stop_words: Set[str] = set()
writing_system = {"direction": "ltr", "has_case": True, "has_letters": True}
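With lex_attr_data added to BaseDefaults, a language can ship both the getter functions and the plain data they consume; a sketch of hypothetical defaults, assuming the create_vocab() change later in this diff that copies stop_words into lex_attr_data["stops"]:

    from spacy.language import BaseDefaults
    from spacy.attrs import NORM

    class CustomDefaults(BaseDefaults):  # illustrative, not part of this commit
        lex_attr_getters = {NORM: lambda vocab, string: string.lower()}
        lex_attr_data = {"num_words": ["one", "two", "three"]}
        stop_words = {"a", "an", "the"}  # copied into lex_attr_data["stops"] by create_vocab()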

View File

@@ -2,7 +2,7 @@ from numpy cimport ndarray
from .typedefs cimport attr_t, hash_t, flags_t, len_t, tag_t
from .attrs cimport attr_id_t
from .attrs cimport ID, ORTH, LOWER, NORM, SHAPE, PREFIX, SUFFIX, LENGTH, LANG
from .attrs cimport ID, ORTH, LOWER, NORM, SHAPE, PREFIX, SUFFIX, LENGTH
from .structs cimport LexemeC
from .vocab cimport Vocab
@@ -40,8 +40,6 @@ cdef class Lexeme:
lex.prefix = value
elif name == SUFFIX:
lex.suffix = value
elif name == LANG:
lex.lang = value
@staticmethod
cdef inline attr_t get_struct_attr(const LexemeC* lex, attr_id_t feat_name) nogil:
@@ -66,8 +64,6 @@ cdef class Lexeme:
return lex.suffix
elif feat_name == LENGTH:
return lex.length
elif feat_name == LANG:
return lex.lang
else:
return 0

View File

@@ -247,13 +247,9 @@ cdef class Lexeme:
cluster_table = self.vocab.lookups.get_table("lexeme_cluster", {})
cluster_table[self.c.orth] = x
property lang:
"""RETURNS (uint64): Language of the parent vocabulary."""
def __get__(self):
return self.c.lang
def __set__(self, attr_t x):
self.c.lang = x
@property
def lang(self):
return self.vocab.strings[self.vocab.lang]
property prob:
"""RETURNS (float): Smoothed log probability estimate of the lexeme's
@@ -316,13 +312,9 @@ cdef class Lexeme:
def __set__(self, str x):
self.c.suffix = self.vocab.strings.add(x)
property lang_:
"""RETURNS (str): Language of the parent vocabulary."""
def __get__(self):
return self.vocab.strings[self.c.lang]
def __set__(self, str x):
self.c.lang = self.vocab.strings.add(x)
@property
def lang_(self):
return self.vocab.lang
property flags:
"""RETURNS (uint64): Container of the lexeme's binary flags."""

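With the lang field gone from LexemeC, a lexeme's language is derived from its vocab rather than stored per entry; a quick sketch of the expected behaviour, assuming an English pipeline:

    from spacy.lang.en import English

    nlp = English()
    lex = nlp.vocab["dog"]
    assert lex.lang_ == "en"                    # now read from vocab.lang
    assert lex.lang == nlp.vocab.strings["en"]  # hash of the same string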
View File

@@ -11,8 +11,6 @@ from .parts_of_speech cimport univ_pos_t
cdef struct LexemeC:
flags_t flags
attr_t lang
attr_t id
attr_t length

View File

@@ -9,8 +9,8 @@ from spacy.lang.lex_attrs import like_url, word_shape
@pytest.mark.parametrize("word", ["the"])
@pytest.mark.issue(1889)
def test_issue1889(word):
assert is_stop(word, STOP_WORDS) == is_stop(word.upper(), STOP_WORDS)
def test_issue1889(en_vocab, word):
assert is_stop(en_vocab, word) == is_stop(en_vocab, word.upper())
@pytest.mark.parametrize("text", ["dog"])
@@ -59,13 +59,13 @@ def test_attrs_ent_iob_intify():
@pytest.mark.parametrize("text,match", [(",", True), (" ", False), ("a", False)])
def test_lex_attrs_is_punct(text, match):
assert is_punct(text) == match
def test_lex_attrs_is_punct(en_vocab, text, match):
assert is_punct(en_vocab, text) == match
@pytest.mark.parametrize("text,match", [(",", True), ("£", False), ("", False)])
def test_lex_attrs_is_ascii(text, match):
assert is_ascii(text) == match
def test_lex_attrs_is_ascii(en_vocab, text, match):
assert is_ascii(en_vocab, text) == match
@pytest.mark.parametrize(
@@ -82,8 +82,8 @@ def test_lex_attrs_is_ascii(text, match):
("dog", False),
],
)
def test_lex_attrs_is_currency(text, match):
assert is_currency(text) == match
def test_lex_attrs_is_currency(en_vocab, text, match):
assert is_currency(en_vocab, text) == match
@pytest.mark.parametrize(
@@ -102,8 +102,8 @@ def test_lex_attrs_is_currency(text, match):
("hello.There", False),
],
)
def test_lex_attrs_like_url(text, match):
assert like_url(text) == match
def test_lex_attrs_like_url(en_vocab, text, match):
assert like_url(en_vocab, text) == match
@pytest.mark.parametrize(
@@ -118,5 +118,5 @@ def test_lex_attrs_like_url(text, match):
("``,-", "``,-"),
],
)
def test_lex_attrs_word_shape(text, shape):
assert word_shape(text) == shape
def test_lex_attrs_word_shape(en_vocab, text, shape):
assert word_shape(en_vocab, text) == shape
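The same pattern applies outside the test suite: the module-level helpers can still be called directly, but every call now passes a vocab; a sketch using a bare Vocab, which is enough for getters that ignore it:

    from spacy.vocab import Vocab
    from spacy.lang.lex_attrs import is_punct, like_url, word_shape

    vocab = Vocab()
    assert is_punct(vocab, ",") is True
    assert like_url(vocab, "http://example.com") is True
    assert word_shape(vocab, "Ab3") == "Xxd"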

View File

@@ -6,10 +6,12 @@ from thinc.api import get_current_ops
import spacy
from spacy.lang.en import English
from spacy.strings import StringStore
from spacy.symbols import NORM
from spacy.tokens import Doc
from spacy.util import ensure_path, load_model
from spacy.vectors import Vectors
from spacy.vocab import Vocab
from spacy.lang.lex_attrs import norm
from ..util import make_tempdir
@@ -73,6 +75,7 @@ def test_serialize_vocab(en_vocab, text):
new_vocab = Vocab().from_bytes(vocab_bytes)
assert new_vocab.strings[text_hash] == text
assert new_vocab.to_bytes(exclude=["lookups"]) == vocab_bytes
assert new_vocab.lex_attr_data == en_vocab.lex_attr_data
@pytest.mark.parametrize("strings1,strings2", test_strings)
@@ -111,12 +114,13 @@ def test_serialize_vocab_roundtrip_disk(strings1, strings2):
assert [s for s in vocab1_d.strings] == [s for s in vocab2_d.strings]
else:
assert [s for s in vocab1_d.strings] != [s for s in vocab2_d.strings]
assert vocab1_d.lex_attr_data == vocab2_d.lex_attr_data
@pytest.mark.parametrize("strings,lex_attr", test_strings_attrs)
def test_serialize_vocab_lex_attrs_bytes(strings, lex_attr):
vocab1 = Vocab(strings=strings)
vocab2 = Vocab()
vocab1 = Vocab(strings=strings, lex_attr_getters={NORM: norm})
vocab2 = Vocab(lex_attr_getters={NORM: norm})
vocab1[strings[0]].norm_ = lex_attr
assert vocab1[strings[0]].norm_ == lex_attr
assert vocab2[strings[0]].norm_ != lex_attr
@@ -134,8 +138,8 @@ def test_deserialize_vocab_seen_entries(strings, lex_attr):
@pytest.mark.parametrize("strings,lex_attr", test_strings_attrs)
def test_serialize_vocab_lex_attrs_disk(strings, lex_attr):
vocab1 = Vocab(strings=strings)
vocab2 = Vocab()
vocab1 = Vocab(strings=strings, lex_attr_getters={NORM: norm})
vocab2 = Vocab(lex_attr_getters={NORM: norm})
vocab1[strings[0]].norm_ = lex_attr
assert vocab1[strings[0]].norm_ == lex_attr
assert vocab2[strings[0]].norm_ != lex_attr
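A sketch of the round trip these assertions rely on: lex_attr_data is plain serializable data, so it survives to_bytes()/from_bytes(), while the getter functions themselves still have to be supplied again on the receiving vocab:

    from spacy.vocab import Vocab
    from spacy.attrs import NORM
    from spacy.lang.lex_attrs import norm

    vocab1 = Vocab(lex_attr_getters={NORM: norm}, lex_attr_data={"stops": ["a", "the"]})
    vocab2 = Vocab(lex_attr_getters={NORM: norm})  # getters passed again by hand
    vocab2.from_bytes(vocab1.to_bytes())
    assert vocab2.lex_attr_data == vocab1.lex_attr_data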

View File

@@ -288,7 +288,7 @@ cdef class Token:
"""RETURNS (uint64): ID of the language of the parent document's
vocabulary.
"""
return self.c.lex.lang
return self.doc.lang
@property
def idx(self):
@@ -844,7 +844,7 @@ cdef class Token:
"""RETURNS (str): Language of the parent document's vocabulary,
e.g. 'en'.
"""
return self.vocab.strings[self.c.lex.lang]
return self.doc.lang_
property lemma_:
"""RETURNS (str): The token lemma, i.e. the base form of the word,

View File

@@ -1160,28 +1160,6 @@ def compile_infix_regex(entries: Iterable[Union[str, Pattern]]) -> Pattern:
return re.compile(expression)
def add_lookups(default_func: Callable[[str], Any], *lookups) -> Callable[[str], Any]:
"""Extend an attribute function with special cases. If a word is in the
lookups, the value is returned. Otherwise the previous function is used.
default_func (callable): The default function to execute.
*lookups (dict): Lookup dictionary mapping string to attribute value.
RETURNS (callable): Lexical attribute getter.
"""
# This is implemented as functools.partial instead of a closure, to allow
# pickle to work.
return functools.partial(_get_attr_unless_lookup, default_func, lookups)
def _get_attr_unless_lookup(
default_func: Callable[[str], Any], lookups: Dict[str, Any], string: str
) -> Any:
for lookup in lookups:
if string in lookup:
return lookup[string] # type: ignore[index]
return default_func(string)
def update_exc(
base_exceptions: Dict[str, List[dict]], *addition_dicts
) -> Dict[str, List[dict]]:
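add_lookups() can be dropped because special-case norms no longer have to be merged into the getter with functools.partial: the new norm() getter near the top of this diff consults vocab.lookups on every call. A small sketch of the equivalent behaviour after this change:

    from spacy.vocab import Vocab
    from spacy.attrs import NORM
    from spacy.lang.lex_attrs import norm

    vocab = Vocab(lex_attr_getters={NORM: norm})
    vocab.lookups.add_table("lexeme_norm", {"gonna": "going to"})
    assert vocab["gonna"].norm_ == "going to"  # exception found in the lookups table
    assert vocab["dogs"].norm_ == "dogs"       # falls back to string.lower()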

View File

@@ -27,12 +27,14 @@ cdef class Vocab:
cdef Pool mem
cdef readonly StringStore strings
cdef public Morphology morphology
cdef public str lang
cdef public object _vectors
cdef public object _lookups
cdef public object _lex_attr_getters
cdef public object _lex_attr_data
cdef public object writing_system
cdef public object get_noun_chunks
cdef readonly int length
cdef public object lex_attr_getters
cdef public object cfg
cdef const LexemeC* get(self, str string) except NULL

View File

@@ -5,6 +5,7 @@ import numpy
import srsly
from thinc.api import get_array_module, get_current_ops
import functools
import inspect
from .lexeme cimport EMPTY_LEXEME, OOV_RANK
from .lexeme cimport Lexeme
@@ -14,29 +15,26 @@ from .attrs cimport LANG, ORTH
from .compat import copy_reg
from .errors import Errors
from .attrs import intify_attrs, NORM, IS_STOP
from .attrs import intify_attrs, NORM
from .vectors import Vectors, Mode as VectorsMode
from .util import registry
from .lookups import Lookups
from . import util
from .lang.norm_exceptions import BASE_NORMS
from .lang.lex_attrs import LEX_ATTRS, is_stop, get_lang
from .lang.lex_attrs import LEX_ATTRS
def create_vocab(lang, defaults, vectors_name=None):
# If the spacy-lookups-data package is installed, we pre-populate the lookups
# with lexeme data, if available
lex_attrs = {**LEX_ATTRS, **defaults.lex_attr_getters}
# This is messy, but it's the minimal working fix to Issue #639.
lex_attrs[IS_STOP] = functools.partial(is_stop, stops=defaults.stop_words)
# Ensure that getter can be pickled
lex_attrs[LANG] = functools.partial(get_lang, lang=lang)
lex_attrs[NORM] = util.add_lookups(
lex_attrs.get(NORM, LEX_ATTRS[NORM]),
BASE_NORMS,
)
defaults.lex_attr_data["stops"] = list(defaults.stop_words)
lookups = Lookups()
lookups.add_table("lexeme_norm", BASE_NORMS)
return Vocab(
lang=lang,
lex_attr_getters=lex_attrs,
lex_attr_data=defaults.lex_attr_data,
writing_system=defaults.writing_system,
get_noun_chunks=defaults.syntax_iterators.get("noun_chunks"),
vectors_name=vectors_name,
@@ -52,7 +50,8 @@ cdef class Vocab:
"""
def __init__(self, lex_attr_getters=None, strings=tuple(), lookups=None,
oov_prob=-20., vectors_name=None, writing_system={},
get_noun_chunks=None, **deprecated_kwargs):
get_noun_chunks=None, lang="", lex_attr_data=None,
**deprecated_kwargs):
"""Create the vocabulary.
lex_attr_getters (dict): A dictionary mapping attribute IDs to
@@ -66,9 +65,11 @@
A function that yields base noun phrases used for Doc.noun_chunks.
"""
lex_attr_getters = lex_attr_getters if lex_attr_getters is not None else {}
lex_attr_data = lex_attr_data if lex_attr_data is not None else {}
if lookups in (None, True, False):
lookups = Lookups()
self.cfg = {'oov_prob': oov_prob}
self.lang = lang
self.mem = Pool()
self._by_orth = PreshMap()
self.strings = StringStore()
@@ -77,6 +78,7 @@
for string in strings:
_ = self[string]
self.lex_attr_getters = lex_attr_getters
self.lex_attr_data = lex_attr_data
self.morphology = Morphology(self.strings)
self.vectors = Vectors(strings=self.strings, name=vectors_name)
self.lookups = lookups
@@ -93,13 +95,6 @@
self._vectors = vectors
self._vectors.strings = self.strings
@property
def lang(self):
langfunc = None
if self.lex_attr_getters:
langfunc = self.lex_attr_getters.get(LANG, None)
return langfunc("_") if langfunc else ""
def __len__(self):
"""The current number of lexemes stored.
@@ -183,7 +178,7 @@
lex.id = OOV_RANK
if self.lex_attr_getters is not None:
for attr, func in self.lex_attr_getters.items():
value = func(string)
value = _get_lex_attr_value(self, func, string)
if isinstance(value, str):
value = self.strings.add(value)
if value is not None:
@@ -433,12 +428,23 @@
def __set__(self, lookups):
self._lookups = lookups
if lookups.has_table("lexeme_norm"):
self.lex_attr_getters[NORM] = util.add_lookups(
self.lex_attr_getters.get(NORM, LEX_ATTRS[NORM]),
self.lookups.get_table("lexeme_norm"),
)
self._reset_lexeme_cache()
property lex_attr_getters:
def __get__(self):
return self._lex_attr_getters
def __set__(self, lex_attr_getters):
self._lex_attr_getters = lex_attr_getters
self._reset_lexeme_cache()
property lex_attr_data:
def __get__(self):
return self._lex_attr_data
def __set__(self, lex_attr_data):
self._lex_attr_data = lex_attr_data
self._reset_lexeme_cache()
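Because both setters call _reset_lexeme_cache(), getters or data swapped in after construction apply to lexemes created from then on; a sketch (the lambda is an arbitrary illustrative getter, not part of this commit):

    from spacy.vocab import Vocab
    from spacy.attrs import NORM

    vocab = Vocab()
    _ = vocab["Dog"]                    # cached with the default (empty) getters
    vocab.lex_attr_getters = {NORM: lambda vocab, string: string.lower()}
    assert vocab["Dog"].norm_ == "dog"  # cache was cleared, lexeme recomputed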
def to_disk(self, path, *, exclude=tuple()):
"""Save the current state to a directory.
@@ -459,6 +465,8 @@
self.vectors.to_disk(path, exclude=["strings"])
if "lookups" not in exclude:
self.lookups.to_disk(path)
if "lex_attr_data" not in exclude:
srsly.write_msgpack(path / "lex_attr_data", self.lex_attr_data)
def from_disk(self, path, *, exclude=tuple()):
"""Loads state from a directory. Modifies the object in place and
@@ -471,7 +479,6 @@
DOCS: https://spacy.io/api/vocab#to_disk
"""
path = util.ensure_path(path)
getters = ["strings", "vectors"]
if "strings" not in exclude:
self.strings.from_disk(path / "strings.json") # TODO: add exclude?
if "vectors" not in exclude:
@@ -479,12 +486,9 @@
self.vectors.from_disk(path, exclude=["strings"])
if "lookups" not in exclude:
self.lookups.from_disk(path)
if "lexeme_norm" in self.lookups:
self.lex_attr_getters[NORM] = util.add_lookups(
self.lex_attr_getters.get(NORM, LEX_ATTRS[NORM]), self.lookups.get_table("lexeme_norm")
)
self.length = 0
self._by_orth = PreshMap()
if "lex_attr_data" not in exclude:
self.lex_attr_data = srsly.read_msgpack(path / "lex_attr_data")
self._reset_lexeme_cache()
return self
def to_bytes(self, *, exclude=tuple()):
@@ -505,6 +509,7 @@
"strings": lambda: self.strings.to_bytes(),
"vectors": deserialize_vectors,
"lookups": lambda: self.lookups.to_bytes(),
"lex_attr_data": lambda: srsly.msgpack_dumps(self.lex_attr_data)
}
return util.to_bytes(getters, exclude)
@@ -523,24 +528,27 @@
else:
return self.vectors.from_bytes(b, exclude=["strings"])
def serialize_lex_attr_data(b):
self.lex_attr_data = srsly.msgpack_loads(b)
setters = {
"strings": lambda b: self.strings.from_bytes(b),
"vectors": lambda b: serialize_vectors(b),
"lookups": lambda b: self.lookups.from_bytes(b),
"lex_attr_data": lambda b: serialize_lex_attr_data(b),
}
util.from_bytes(bytes_data, setters, exclude)
if "lexeme_norm" in self.lookups:
self.lex_attr_getters[NORM] = util.add_lookups(
self.lex_attr_getters.get(NORM, LEX_ATTRS[NORM]), self.lookups.get_table("lexeme_norm")
)
self.length = 0
self._by_orth = PreshMap()
self._reset_lexeme_cache()
return self
def _reset_cache(self, keys, strings):
# I'm not sure this made sense. Disable it for now.
raise NotImplementedError
def _reset_lexeme_cache(self):
self.length = 0
self._by_orth = PreshMap()
def pickle_vocab(vocab):
sstore = vocab.strings
@@ -565,3 +573,12 @@ def unpickle_vocab(sstore, vectors, morphology, lex_attr_getters, lookups, get_n
copy_reg.pickle(Vocab, pickle_vocab, unpickle_vocab)
def _get_lex_attr_value(vocab, func, string):
if "vocab" in inspect.signature(func).parameters:
value = func(vocab, string)
else:
# TODO: add deprecation warning
value = func(string)
return value
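The inspect-based dispatch keeps old-style one-argument getters working during the transition; a sketch of both shapes the helper accepts (both getters below are illustrative, not shipped by this commit):

    from spacy.vocab import Vocab
    from spacy.attrs import LOWER, NORM

    def old_style(string):           # legacy signature, dispatched as func(string)
        return string.lower()

    def new_style(vocab, string):    # new signature, receives the vocab
        return vocab.lookups.get_table("lexeme_norm", {}).get(string, string.lower())

    vocab = Vocab(lex_attr_getters={LOWER: old_style, NORM: new_style})
    assert vocab["Dog"].lower_ == "dog"
    assert vocab["Dog"].norm_ == "dog"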