diff --git a/spacy/lexeme.pyx b/spacy/lexeme.pyx index e57098f17..e90505b6d 100644 --- a/spacy/lexeme.pyx +++ b/spacy/lexeme.pyx @@ -17,6 +17,7 @@ from .attrs cimport IS_CURRENCY from .attrs import intify_attrs from .errors import Errors, Warnings +from .util import DEFAULT_OOV_PROB OOV_RANK = 0xffffffffffffffff # UINT64_MAX @@ -244,7 +245,10 @@ cdef class Lexeme: return cluster_table.get(self.c.orth, 0) def __set__(self, int x): - cluster_table = self.vocab.lookups.get_table("lexeme_cluster", {}) + if "lexeme_cluster" in self.vocab.lookups: + cluster_table = self.vocab.lookups.get_table("lexeme_cluster") + else: + cluster_table = self.vocab.lookups.add_table("lexeme_cluster") cluster_table[self.c.orth] = x property lang: @@ -261,11 +265,14 @@ cdef class Lexeme: def __get__(self): prob_table = self.vocab.lookups.get_table("lexeme_prob", {}) settings_table = self.vocab.lookups.get_table("lexeme_settings", {}) - default_oov_prob = settings_table.get("oov_prob", -20.0) + default_oov_prob = settings_table.get("oov_prob", DEFAULT_OOV_PROB) return prob_table.get(self.c.orth, default_oov_prob) def __set__(self, float x): - prob_table = self.vocab.lookups.get_table("lexeme_prob", {}) + if "lexeme_prob" in self.vocab.lookups: + prob_table = self.vocab.lookups.get_table("lexeme_prob") + else: + prob_table = self.vocab.lookups.add_table("lexeme_prob") prob_table[self.c.orth] = x property lower_: diff --git a/spacy/tests/doc/test_token_api.py b/spacy/tests/doc/test_token_api.py index e715c5e85..fb87943a6 100644 --- a/spacy/tests/doc/test_token_api.py +++ b/spacy/tests/doc/test_token_api.py @@ -5,6 +5,7 @@ from spacy.symbols import VERB from spacy.vocab import Vocab from spacy.tokens import Doc from spacy.training import Example +from spacy.util import DEFAULT_OOV_PROB @pytest.fixture @@ -48,12 +49,16 @@ def test_doc_token_api_flags(en_tokenizer): # TODO: Test more of these, esp. if a bug is found -@pytest.mark.parametrize("text", ["Give it back! 
He pleaded."]) -def test_doc_token_api_prob_inherited_from_vocab(en_tokenizer, text): - word = text.split()[0] - en_tokenizer.vocab[word].prob = -1 - tokens = en_tokenizer(text) - assert tokens[0].prob != 0 +@pytest.mark.parametrize("words", [["a", "b"]]) +def test_doc_token_api_prob_inherited_from_vocab(words): + vocab = Vocab() + # setting a prob adds lexeme_prob and sets the value + vocab[words[0]].prob = -1 + assert vocab.lookups.get_table("lexeme_prob")[words[0]] == -1 + # vocab probs are reflected in tokens + tokens = Doc(vocab, words=words) + assert tokens[0].prob == -1 + assert tokens[1].prob == DEFAULT_OOV_PROB @pytest.mark.parametrize("text", ["one two"]) diff --git a/spacy/tests/vocab_vectors/test_lexeme.py b/spacy/tests/vocab_vectors/test_lexeme.py index d91f41db3..d0cbec991 100644 --- a/spacy/tests/vocab_vectors/test_lexeme.py +++ b/spacy/tests/vocab_vectors/test_lexeme.py @@ -3,7 +3,7 @@ import pytest from spacy.attrs import IS_ALPHA, IS_DIGIT from spacy.lookups import Lookups from spacy.tokens import Doc -from spacy.util import OOV_RANK +from spacy.util import OOV_RANK, DEFAULT_OOV_PROB from spacy.vocab import Vocab @@ -22,16 +22,21 @@ def test_issue600(): doc[0].tag_ = "NN" -@pytest.mark.parametrize("text1,prob1,text2,prob2", [("NOUN", -1, "opera", -2)]) -def test_vocab_lexeme_lt(en_vocab, text1, text2, prob1, prob2): - """More frequent is l.t. 
less frequent""" - lex1 = en_vocab[text1] - lex1.prob = prob1 - lex2 = en_vocab[text2] - lex2.prob = prob2 - - assert lex1 < lex2 - assert lex2 > lex1 +@pytest.mark.parametrize("word1,prob1,word2,prob2", [("NOUN", -1, "opera", -2)]) +def test_vocab_lexeme_prob(word1, word2, prob1, prob2): + """lexeme_prob table is only added if custom probs are set""" + vocab = Vocab() + # blank vocab does not include this table + assert "lexeme_prob" not in vocab.lookups.tables + # accessing a prob does not add a table + assert vocab["OTHER"].prob == DEFAULT_OOV_PROB + assert "lexeme_prob" not in vocab.lookups.tables + # table is added when a prob is set + vocab[word1].prob = prob1 + assert "lexeme_prob" in vocab.lookups.tables + assert word2 not in vocab.lookups.get_table("lexeme_prob") + vocab[word2].prob = prob2 + assert vocab[word1].prob > vocab[word2].prob @pytest.mark.parametrize("text1,text2", [("phantom", "opera")]) diff --git a/spacy/training/initialize.py b/spacy/training/initialize.py index 6304e4a84..2e0df044e 100644 --- a/spacy/training/initialize.py +++ b/spacy/training/initialize.py @@ -18,7 +18,7 @@ from ..errors import Errors, Warnings from ..schemas import ConfigSchemaTraining from ..util import registry, load_model_from_config, resolve_dot_names, logger from ..util import load_model, ensure_path, get_sourced_components -from ..util import OOV_RANK, DEFAULT_OOV_PROB +from ..util import OOV_RANK if TYPE_CHECKING: from ..language import Language # noqa: F401 @@ -120,11 +120,6 @@ def init_vocab( continue lexeme = nlp.vocab[attrs["orth"]] lexeme.set_attrs(**attrs) - if len(nlp.vocab): - oov_prob = min(lex.prob for lex in nlp.vocab) - 1 - else: - oov_prob = DEFAULT_OOV_PROB - nlp.vocab.cfg.update({"oov_prob": oov_prob}) logger.info(f"Added {len(nlp.vocab)} lexical entries to the vocab") logger.info("Created vocabulary") if vectors is not None: diff --git a/spacy/vocab.pxd b/spacy/vocab.pxd index 2db709b71..13fe99528 100644 --- a/spacy/vocab.pxd +++ b/spacy/vocab.pxd 
@@ -33,7 +33,6 @@ cdef class Vocab: cdef public object get_noun_chunks cdef readonly int length cdef public object lex_attr_getters - cdef public object cfg cdef const LexemeC* get(self, str string) except NULL cdef const LexemeC* get_by_orth(self, attr_t orth) except NULL diff --git a/spacy/vocab.pyi b/spacy/vocab.pyi index 41964703b..df3cc20c8 100644 --- a/spacy/vocab.pyi +++ b/spacy/vocab.pyi @@ -27,7 +27,6 @@ class Vocab: lex_attr_getters: Optional[Dict[str, Callable[[str], Any]]] = ..., strings: Optional[Union[List[str], StringStore]] = ..., lookups: Optional[Lookups] = ..., - oov_prob: float = ..., vectors_name: Optional[str] = ..., writing_system: Dict[str, Any] = ..., get_noun_chunks: Optional[Callable[[Union[Doc, Span]], Iterator[Span]]] = ..., diff --git a/spacy/vocab.pyx b/spacy/vocab.pyx index a87f50ad4..f052a4ec8 100644 --- a/spacy/vocab.pyx +++ b/spacy/vocab.pyx @@ -51,8 +51,8 @@ cdef class Vocab: DOCS: https://spacy.io/api/vocab """ def __init__(self, lex_attr_getters=None, strings=tuple(), lookups=None, - oov_prob=-20., vectors_name=None, writing_system={}, - get_noun_chunks=None, **deprecated_kwargs): + vectors_name=None, writing_system={}, get_noun_chunks=None, + **deprecated_kwargs): """Create the vocabulary. lex_attr_getters (dict): A dictionary mapping attribute IDs to @@ -60,7 +60,6 @@ cdef class Vocab: strings (StringStore): StringStore that maps strings to integers, and vice versa. lookups (Lookups): Container for large lookup tables and dictionaries. - oov_prob (float): Default OOV probability. vectors_name (str): Optional name to identify the vectors table. get_noun_chunks (Optional[Callable[[Union[Doc, Span], Iterator[Tuple[int, int, int]]]]]): A function that yields base noun phrases used for Doc.noun_chunks. 
@@ -68,7 +67,6 @@ cdef class Vocab: lex_attr_getters = lex_attr_getters if lex_attr_getters is not None else {} if lookups in (None, True, False): lookups = Lookups() - self.cfg = {'oov_prob': oov_prob} self.mem = Pool() self._by_orth = PreshMap() self.strings = StringStore() diff --git a/website/docs/api/vocab.mdx b/website/docs/api/vocab.mdx index 131e4ce0a..ee0785ff9 100644 --- a/website/docs/api/vocab.mdx +++ b/website/docs/api/vocab.mdx @@ -26,7 +26,6 @@ Create the vocabulary. | `lex_attr_getters` | A dictionary mapping attribute IDs to functions to compute them. Defaults to `None`. ~~Optional[Dict[str, Callable[[str], Any]]]~~ | | `strings` | A [`StringStore`](/api/stringstore) that maps strings to hash values, and vice versa, or a list of strings. ~~Union[List[str], StringStore]~~ | | `lookups` | A [`Lookups`](/api/lookups) that stores the `lexeme_norm` and other large lookup tables. Defaults to `None`. ~~Optional[Lookups]~~ | -| `oov_prob` | The default OOV probability. Defaults to `-20.0`. ~~float~~ | | `vectors_name` | A name to identify the vectors table. ~~str~~ | | `writing_system` | A dictionary describing the language's writing system. Typically provided by [`Language.Defaults`](/api/language#defaults). ~~Dict[str, Any]~~ | | `get_noun_chunks` | A function that yields base noun phrases used for [`Doc.noun_chunks`](/api/doc#noun_chunks). ~~Optional[Callable[[Union[Doc, Span], Iterator[Tuple[int, int, int]]]]]~~ |