Remove Vocab.oov_prob, fix lexeme_ lookups table creation

Remove `Vocab.oov_prob` and `Vocab.cfg` (both unused) and fix how
`lexeme_prob` and `lexeme_cluster` tables are added dynamically when an
attribute is set on a vocab where these tables don't already exist.
Adriane Boyd 2023-02-07 12:46:11 +01:00
parent eec5ccd72f
commit aac8e1ba99
8 changed files with 40 additions and 33 deletions
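
In effect, setting `Lexeme.prob` or `Lexeme.cluster` on a vocab that does not yet have the corresponding table now registers that table in `vocab.lookups`, rather than writing into the unregistered `{}` default that `get_table(name, {})` returns when the table is missing. A minimal sketch of the behavior with this change applied (illustration only, not part of the commit; the word and value are arbitrary):

from spacy.vocab import Vocab

vocab = Vocab()
# a blank vocab starts without a lexeme_prob table
assert "lexeme_prob" not in vocab.lookups.tables
# setting a prob now adds the table and persists the value
vocab["opera"].prob = -2.5
assert "lexeme_prob" in vocab.lookups.tables
assert vocab.lookups.get_table("lexeme_prob")["opera"] == -2.5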

View File

@@ -17,6 +17,7 @@ from .attrs cimport IS_CURRENCY
 from .attrs import intify_attrs
 from .errors import Errors, Warnings
+from .util import DEFAULT_OOV_PROB

 OOV_RANK = 0xffffffffffffffff # UINT64_MAX
@@ -244,7 +245,10 @@ cdef class Lexeme:
             return cluster_table.get(self.c.orth, 0)

         def __set__(self, int x):
-            cluster_table = self.vocab.lookups.get_table("lexeme_cluster", {})
+            if "lexeme_cluster" in self.vocab.lookups:
+                cluster_table = self.vocab.lookups.get_table("lexeme_cluster")
+            else:
+                cluster_table = self.vocab.lookups.add_table("lexeme_cluster")
             cluster_table[self.c.orth] = x

     property lang:
@@ -261,11 +265,14 @@ cdef class Lexeme:
         def __get__(self):
             prob_table = self.vocab.lookups.get_table("lexeme_prob", {})
             settings_table = self.vocab.lookups.get_table("lexeme_settings", {})
-            default_oov_prob = settings_table.get("oov_prob", -20.0)
+            default_oov_prob = settings_table.get("oov_prob", DEFAULT_OOV_PROB)
             return prob_table.get(self.c.orth, default_oov_prob)

         def __set__(self, float x):
-            prob_table = self.vocab.lookups.get_table("lexeme_prob", {})
+            if "lexeme_prob" in self.vocab.lookups:
+                prob_table = self.vocab.lookups.get_table("lexeme_prob")
+            else:
+                prob_table = self.vocab.lookups.add_table("lexeme_prob")
             prob_table[self.c.orth] = x

     property lower_:
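
As the `prob` getter above shows, the out-of-vocabulary default is read from an optional `lexeme_settings` table under the key `"oov_prob"`, with `DEFAULT_OOV_PROB` from `spacy.util` as the final fallback. A rough sketch of overriding that default (an illustration based on the getter above, not code from this commit; in practice the table usually comes from lookups data):

from spacy.vocab import Vocab
from spacy.util import DEFAULT_OOV_PROB

vocab = Vocab()
# without a lexeme_settings table, unknown lexemes get the package default
assert vocab["unseen"].prob == DEFAULT_OOV_PROB
# a custom default can be stored under "oov_prob" in lexeme_settings
vocab.lookups.add_table("lexeme_settings", {"oov_prob": -10.0})
assert vocab["also-unseen"].prob == -10.0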

View File

@@ -5,6 +5,7 @@ from spacy.symbols import VERB
 from spacy.vocab import Vocab
 from spacy.tokens import Doc
 from spacy.training import Example
+from spacy.util import DEFAULT_OOV_PROB


 @pytest.fixture
@@ -48,12 +49,16 @@ def test_doc_token_api_flags(en_tokenizer):
 # TODO: Test more of these, esp. if a bug is found
-@pytest.mark.parametrize("text", ["Give it back! He pleaded."])
-def test_doc_token_api_prob_inherited_from_vocab(en_tokenizer, text):
-    word = text.split()[0]
-    en_tokenizer.vocab[word].prob = -1
-    tokens = en_tokenizer(text)
-    assert tokens[0].prob != 0
+@pytest.mark.parametrize("words", [["a", "b"]])
+def test_doc_token_api_prob_inherited_from_vocab(words):
+    vocab = Vocab()
+    # setting a prob adds lexeme_prob and sets the value
+    vocab[words[0]].prob = -1
+    assert vocab.lookups.get_table("lexeme_prob")[words[0]] == -1
+    # vocab probs are reflected in tokens
+    tokens = Doc(vocab, words=words)
+    assert tokens[0].prob == -1
+    assert tokens[1].prob == DEFAULT_OOV_PROB


 @pytest.mark.parametrize("text", ["one two"])

View File

@@ -3,7 +3,7 @@ import pytest
 from spacy.attrs import IS_ALPHA, IS_DIGIT
 from spacy.lookups import Lookups
 from spacy.tokens import Doc
-from spacy.util import OOV_RANK
+from spacy.util import OOV_RANK, DEFAULT_OOV_PROB
 from spacy.vocab import Vocab
@@ -22,16 +22,21 @@ def test_issue600():
     doc[0].tag_ = "NN"


-@pytest.mark.parametrize("text1,prob1,text2,prob2", [("NOUN", -1, "opera", -2)])
-def test_vocab_lexeme_lt(en_vocab, text1, text2, prob1, prob2):
-    """More frequent is l.t. less frequent"""
-    lex1 = en_vocab[text1]
-    lex1.prob = prob1
-    lex2 = en_vocab[text2]
-    lex2.prob = prob2
-    assert lex1 < lex2
-    assert lex2 > lex1
+@pytest.mark.parametrize("word1,prob1,word2,prob2", [("NOUN", -1, "opera", -2)])
+def test_vocab_lexeme_prob(word1, word2, prob1, prob2):
+    """lexeme_prob table is only added if custom probs are set"""
+    vocab = Vocab()
+    # blank vocab does not include this table
+    assert "lexeme_prob" not in vocab.lookups.tables
+    # accessing a prob does not add a table
+    assert vocab["OTHER"].prob == DEFAULT_OOV_PROB
+    assert "lexeme_prob" not in vocab.lookups.tables
+    # table is added when a prob is set
+    vocab[word1].prob = -1
+    assert "lexeme_prob" in vocab.lookups.tables
+    assert word2 not in vocab.lookups.get_table("lexeme_prob")
+    vocab[word2].prob = prob2
+    assert vocab[word1].prob > vocab[word2].prob


 @pytest.mark.parametrize("text1,text2", [("phantom", "opera")])

View File

@@ -18,7 +18,7 @@ from ..errors import Errors, Warnings
 from ..schemas import ConfigSchemaTraining
 from ..util import registry, load_model_from_config, resolve_dot_names, logger
 from ..util import load_model, ensure_path, get_sourced_components
-from ..util import OOV_RANK, DEFAULT_OOV_PROB
+from ..util import OOV_RANK

 if TYPE_CHECKING:
     from ..language import Language  # noqa: F401
@@ -120,11 +120,6 @@ def init_vocab(
                 continue
             lexeme = nlp.vocab[attrs["orth"]]
             lexeme.set_attrs(**attrs)
-        if len(nlp.vocab):
-            oov_prob = min(lex.prob for lex in nlp.vocab) - 1
-        else:
-            oov_prob = DEFAULT_OOV_PROB
-        nlp.vocab.cfg.update({"oov_prob": oov_prob})
         logger.info(f"Added {len(nlp.vocab)} lexical entries to the vocab")
     logger.info("Created vocabulary")
     if vectors is not None:

View File

@@ -33,7 +33,6 @@ cdef class Vocab:
     cdef public object get_noun_chunks
     cdef readonly int length
    cdef public object lex_attr_getters
-    cdef public object cfg

     cdef const LexemeC* get(self, str string) except NULL
     cdef const LexemeC* get_by_orth(self, attr_t orth) except NULL

View File

@@ -27,7 +27,6 @@ class Vocab:
         lex_attr_getters: Optional[Dict[str, Callable[[str], Any]]] = ...,
         strings: Optional[Union[List[str], StringStore]] = ...,
         lookups: Optional[Lookups] = ...,
-        oov_prob: float = ...,
         vectors_name: Optional[str] = ...,
         writing_system: Dict[str, Any] = ...,
         get_noun_chunks: Optional[Callable[[Union[Doc, Span]], Iterator[Span]]] = ...,

View File

@@ -51,8 +51,8 @@ cdef class Vocab:
     DOCS: https://spacy.io/api/vocab
     """
     def __init__(self, lex_attr_getters=None, strings=tuple(), lookups=None,
-                 oov_prob=-20., vectors_name=None, writing_system={},
-                 get_noun_chunks=None, **deprecated_kwargs):
+                 vectors_name=None, writing_system={}, get_noun_chunks=None,
+                 **deprecated_kwargs):
         """Create the vocabulary.

         lex_attr_getters (dict): A dictionary mapping attribute IDs to
@@ -60,7 +60,6 @@ cdef class Vocab:
         strings (StringStore): StringStore that maps strings to integers, and
             vice versa.
         lookups (Lookups): Container for large lookup tables and dictionaries.
-        oov_prob (float): Default OOV probability.
         vectors_name (str): Optional name to identify the vectors table.
         get_noun_chunks (Optional[Callable[[Union[Doc, Span], Iterator[Tuple[int, int, int]]]]]):
             A function that yields base noun phrases used for Doc.noun_chunks.
@@ -68,7 +67,6 @@
         lex_attr_getters = lex_attr_getters if lex_attr_getters is not None else {}
         if lookups in (None, True, False):
             lookups = Lookups()
-        self.cfg = {'oov_prob': oov_prob}
         self.mem = Pool()
         self._by_orth = PreshMap()
         self.strings = StringStore()

View File

@@ -26,7 +26,6 @@ Create the vocabulary.
 | `lex_attr_getters` | A dictionary mapping attribute IDs to functions to compute them. Defaults to `None`. ~~Optional[Dict[str, Callable[[str], Any]]]~~ |
 | `strings` | A [`StringStore`](/api/stringstore) that maps strings to hash values, and vice versa, or a list of strings. ~~Union[List[str], StringStore]~~ |
 | `lookups` | A [`Lookups`](/api/lookups) that stores the `lexeme_norm` and other large lookup tables. Defaults to `None`. ~~Optional[Lookups]~~ |
-| `oov_prob` | The default OOV probability. Defaults to `-20.0`. ~~float~~ |
 | `vectors_name` | A name to identify the vectors table. ~~str~~ |
 | `writing_system` | A dictionary describing the language's writing system. Typically provided by [`Language.Defaults`](/api/language#defaults). ~~Dict[str, Any]~~ |
 | `get_noun_chunks` | A function that yields base noun phrases used for [`Doc.noun_chunks`](/api/doc#noun_chunks). ~~Optional[Callable[[Union[Doc, Span], Iterator[Tuple[int, int, int]]]]]~~ |
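
With `oov_prob` removed, a constructor call using the remaining parameters would look roughly like this (a sketch; the strings and vectors name are illustrative values, not from the docs):

from spacy.vocab import Vocab

# oov_prob is no longer accepted; the other keyword arguments are unchanged
vocab = Vocab(strings=["hello", "world"], vectors_name="example_vectors")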