Clean up Vocab constructor (#12290)

* Clean up Vocab constructor

* Change effective type of `strings` from `Iterable[str]` to `Optional[StringStore]`
  * Don't automatically add strings to vocab
* Change default values to `None`
* Remove `**deprecated_kwargs`

* Format
This commit is contained in:
Adriane Boyd 2023-03-19 23:41:20 +01:00 committed by GitHub
parent 520279ff7c
commit 6ae7618418
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 34 additions and 25 deletions

View File

@ -2,7 +2,7 @@ from typing import List, Optional, Iterable, Iterator, Union, Any, Tuple, overlo
from pathlib import Path from pathlib import Path
class StringStore: class StringStore:
def __init__(self, strings: Optional[Iterable[str]]) -> None: ... def __init__(self, strings: Optional[Iterable[str]] = None) -> None: ...
@overload @overload
def __getitem__(self, string_or_hash: str) -> int: ... def __getitem__(self, string_or_hash: str) -> int: ...
@overload @overload

View File

@ -9,6 +9,7 @@ from spacy.lang.en import English
from spacy.lang.en.syntax_iterators import noun_chunks from spacy.lang.en.syntax_iterators import noun_chunks
from spacy.language import Language from spacy.language import Language
from spacy.pipeline import TrainablePipe from spacy.pipeline import TrainablePipe
from spacy.strings import StringStore
from spacy.tokens import Doc from spacy.tokens import Doc
from spacy.training import Example from spacy.training import Example
from spacy.util import SimpleFrozenList, get_arg_names, make_tempdir from spacy.util import SimpleFrozenList, get_arg_names, make_tempdir
@ -131,7 +132,7 @@ def test_issue5458():
# Test that the noun chuncker does not generate overlapping spans # Test that the noun chuncker does not generate overlapping spans
# fmt: off # fmt: off
words = ["In", "an", "era", "where", "markets", "have", "brought", "prosperity", "and", "empowerment", "."] words = ["In", "an", "era", "where", "markets", "have", "brought", "prosperity", "and", "empowerment", "."]
vocab = Vocab(strings=words) vocab = Vocab(strings=StringStore(words))
deps = ["ROOT", "det", "pobj", "advmod", "nsubj", "aux", "relcl", "dobj", "cc", "conj", "punct"] deps = ["ROOT", "det", "pobj", "advmod", "nsubj", "aux", "relcl", "dobj", "cc", "conj", "punct"]
pos = ["ADP", "DET", "NOUN", "ADV", "NOUN", "AUX", "VERB", "NOUN", "CCONJ", "NOUN", "PUNCT"] pos = ["ADP", "DET", "NOUN", "ADV", "NOUN", "AUX", "VERB", "NOUN", "CCONJ", "NOUN", "PUNCT"]
heads = [0, 2, 0, 9, 6, 6, 2, 6, 7, 7, 0] heads = [0, 2, 0, 9, 6, 6, 2, 6, 7, 7, 0]

View File

@ -13,8 +13,11 @@ from spacy.vocab import Vocab
from ..util import make_tempdir from ..util import make_tempdir
test_strings = [([], []), (["rats", "are", "cute"], ["i", "like", "rats"])] test_strings = [
test_strings_attrs = [(["rats", "are", "cute"], "Hello")] (StringStore(), StringStore()),
(StringStore(["rats", "are", "cute"]), StringStore(["i", "like", "rats"])),
]
test_strings_attrs = [(StringStore(["rats", "are", "cute"]), "Hello")]
@pytest.mark.issue(599) @pytest.mark.issue(599)
@ -81,7 +84,7 @@ def test_serialize_vocab_roundtrip_bytes(strings1, strings2):
vocab2 = Vocab(strings=strings2) vocab2 = Vocab(strings=strings2)
vocab1_b = vocab1.to_bytes() vocab1_b = vocab1.to_bytes()
vocab2_b = vocab2.to_bytes() vocab2_b = vocab2.to_bytes()
if strings1 == strings2: if strings1.to_bytes() == strings2.to_bytes():
assert vocab1_b == vocab2_b assert vocab1_b == vocab2_b
else: else:
assert vocab1_b != vocab2_b assert vocab1_b != vocab2_b
@ -117,11 +120,12 @@ def test_serialize_vocab_roundtrip_disk(strings1, strings2):
def test_serialize_vocab_lex_attrs_bytes(strings, lex_attr): def test_serialize_vocab_lex_attrs_bytes(strings, lex_attr):
vocab1 = Vocab(strings=strings) vocab1 = Vocab(strings=strings)
vocab2 = Vocab() vocab2 = Vocab()
vocab1[strings[0]].norm_ = lex_attr s = next(iter(vocab1.strings))
assert vocab1[strings[0]].norm_ == lex_attr vocab1[s].norm_ = lex_attr
assert vocab2[strings[0]].norm_ != lex_attr assert vocab1[s].norm_ == lex_attr
assert vocab2[s].norm_ != lex_attr
vocab2 = vocab2.from_bytes(vocab1.to_bytes()) vocab2 = vocab2.from_bytes(vocab1.to_bytes())
assert vocab2[strings[0]].norm_ == lex_attr assert vocab2[s].norm_ == lex_attr
@pytest.mark.parametrize("strings,lex_attr", test_strings_attrs) @pytest.mark.parametrize("strings,lex_attr", test_strings_attrs)
@ -136,14 +140,15 @@ def test_deserialize_vocab_seen_entries(strings, lex_attr):
def test_serialize_vocab_lex_attrs_disk(strings, lex_attr): def test_serialize_vocab_lex_attrs_disk(strings, lex_attr):
vocab1 = Vocab(strings=strings) vocab1 = Vocab(strings=strings)
vocab2 = Vocab() vocab2 = Vocab()
vocab1[strings[0]].norm_ = lex_attr s = next(iter(vocab1.strings))
assert vocab1[strings[0]].norm_ == lex_attr vocab1[s].norm_ = lex_attr
assert vocab2[strings[0]].norm_ != lex_attr assert vocab1[s].norm_ == lex_attr
assert vocab2[s].norm_ != lex_attr
with make_tempdir() as d: with make_tempdir() as d:
file_path = d / "vocab" file_path = d / "vocab"
vocab1.to_disk(file_path) vocab1.to_disk(file_path)
vocab2 = vocab2.from_disk(file_path) vocab2 = vocab2.from_disk(file_path)
assert vocab2[strings[0]].norm_ == lex_attr assert vocab2[s].norm_ == lex_attr
@pytest.mark.parametrize("strings1,strings2", test_strings) @pytest.mark.parametrize("strings1,strings2", test_strings)

View File

@ -17,7 +17,7 @@ def test_issue361(en_vocab, text1, text2):
@pytest.mark.issue(600) @pytest.mark.issue(600)
def test_issue600(): def test_issue600():
vocab = Vocab(tag_map={"NN": {"pos": "NOUN"}}) vocab = Vocab()
doc = Doc(vocab, words=["hello"]) doc = Doc(vocab, words=["hello"])
doc[0].tag_ = "NN" doc[0].tag_ = "NN"

View File

@ -26,7 +26,7 @@ class Vocab:
def __init__( def __init__(
self, self,
lex_attr_getters: Optional[Dict[str, Callable[[str], Any]]] = ..., lex_attr_getters: Optional[Dict[str, Callable[[str], Any]]] = ...,
strings: Optional[Union[List[str], StringStore]] = ..., strings: Optional[StringStore] = ...,
lookups: Optional[Lookups] = ..., lookups: Optional[Lookups] = ...,
oov_prob: float = ..., oov_prob: float = ...,
writing_system: Dict[str, Any] = ..., writing_system: Dict[str, Any] = ...,

View File

@ -49,9 +49,8 @@ cdef class Vocab:
DOCS: https://spacy.io/api/vocab DOCS: https://spacy.io/api/vocab
""" """
def __init__(self, lex_attr_getters=None, strings=tuple(), lookups=None, def __init__(self, lex_attr_getters=None, strings=None, lookups=None,
oov_prob=-20., writing_system={}, get_noun_chunks=None, oov_prob=-20., writing_system=None, get_noun_chunks=None):
**deprecated_kwargs):
"""Create the vocabulary. """Create the vocabulary.
lex_attr_getters (dict): A dictionary mapping attribute IDs to lex_attr_getters (dict): A dictionary mapping attribute IDs to
@ -69,15 +68,18 @@ cdef class Vocab:
self.cfg = {'oov_prob': oov_prob} self.cfg = {'oov_prob': oov_prob}
self.mem = Pool() self.mem = Pool()
self._by_orth = PreshMap() self._by_orth = PreshMap()
self.strings = StringStore()
self.length = 0 self.length = 0
if strings: if strings is None:
for string in strings: self.strings = StringStore()
_ = self[string] else:
self.strings = strings
self.lex_attr_getters = lex_attr_getters self.lex_attr_getters = lex_attr_getters
self.morphology = Morphology(self.strings) self.morphology = Morphology(self.strings)
self.vectors = Vectors(strings=self.strings) self.vectors = Vectors(strings=self.strings)
self.lookups = lookups self.lookups = lookups
if writing_system is None:
self.writing_system = {}
else:
self.writing_system = writing_system self.writing_system = writing_system
self.get_noun_chunks = get_noun_chunks self.get_noun_chunks = get_noun_chunks

View File

@ -17,14 +17,15 @@ Create the vocabulary.
> #### Example > #### Example
> >
> ```python > ```python
> from spacy.strings import StringStore
> from spacy.vocab import Vocab > from spacy.vocab import Vocab
> vocab = Vocab(strings=["hello", "world"]) > vocab = Vocab(strings=StringStore(["hello", "world"]))
> ``` > ```
| Name | Description | | Name | Description |
| ------------------ | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- | | ------------------ | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `lex_attr_getters` | A dictionary mapping attribute IDs to functions to compute them. Defaults to `None`. ~~Optional[Dict[str, Callable[[str], Any]]]~~ | | `lex_attr_getters` | A dictionary mapping attribute IDs to functions to compute them. Defaults to `None`. ~~Optional[Dict[str, Callable[[str], Any]]]~~ |
| `strings` | A [`StringStore`](/api/stringstore) that maps strings to hash values, and vice versa, or a list of strings. ~~Union[List[str], StringStore]~~ | | `strings` | A [`StringStore`](/api/stringstore) that maps strings to hash values. ~~Optional[StringStore]~~ |
| `lookups` | A [`Lookups`](/api/lookups) that stores the `lexeme_norm` and other large lookup tables. Defaults to `None`. ~~Optional[Lookups]~~ | | `lookups` | A [`Lookups`](/api/lookups) that stores the `lexeme_norm` and other large lookup tables. Defaults to `None`. ~~Optional[Lookups]~~ |
| `oov_prob` | The default OOV probability. Defaults to `-20.0`. ~~float~~ | | `oov_prob` | The default OOV probability. Defaults to `-20.0`. ~~float~~ |
| `writing_system` | A dictionary describing the language's writing system. Typically provided by [`Language.Defaults`](/api/language#defaults). ~~Dict[str, Any]~~ | | `writing_system` | A dictionary describing the language's writing system. Typically provided by [`Language.Defaults`](/api/language#defaults). ~~Dict[str, Any]~~ |