mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-11-04 01:48:04 +03:00 
			
		
		
		
	* Make changes to typing * Correction * Format with black * Corrections based on review * Bumped Thinc dependency version * Bumped blis requirement * Correction for older Python versions * Update spacy/ml/models/textcat.py Co-authored-by: Daniël de Kok <me@github.danieldk.eu> * Corrections based on review feedback * Readd deleted docstring line Co-authored-by: Daniël de Kok <me@github.danieldk.eu>
		
			
				
	
	
		
			315 lines
		
	
	
		
			11 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			315 lines
		
	
	
		
			11 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
from typing import Any, List, Union, Optional, Dict
 | 
						|
from pathlib import Path
 | 
						|
import srsly
 | 
						|
from preshed.bloom import BloomFilter
 | 
						|
from collections import OrderedDict
 | 
						|
 | 
						|
from .errors import Errors
 | 
						|
from .util import SimpleFrozenDict, ensure_path, registry, load_language_data
 | 
						|
from .strings import get_string_id
 | 
						|
 | 
						|
 | 
						|
UNSET = object()
 | 
						|
 | 
						|
 | 
						|
def load_lookups(lang: str, tables: List[str], strict: bool = True) -> "Lookups":
 | 
						|
    """Load the data from the spacy-lookups-data package for a given language,
 | 
						|
    if available. Returns an empty `Lookups` container if there's no data or if the package
 | 
						|
    is not installed.
 | 
						|
 | 
						|
    lang (str): The language code (corresponds to entry point exposed by
 | 
						|
        the spacy-lookups-data package).
 | 
						|
    tables (List[str]): Name of tables to load, e.g. ["lemma_lookup", "lemma_exc"]
 | 
						|
    strict (bool): Whether to raise an error if a table doesn't exist.
 | 
						|
    RETURNS (Lookups): The lookups container containing the loaded tables.
 | 
						|
    """
 | 
						|
    # TODO: import spacy_lookups_data instead of going via entry points here?
 | 
						|
    lookups = Lookups()
 | 
						|
    if lang not in registry.lookups:
 | 
						|
        if strict and len(tables) > 0:
 | 
						|
            raise ValueError(Errors.E955.format(table=", ".join(tables), lang=lang))
 | 
						|
        return lookups
 | 
						|
    data = registry.lookups.get(lang)
 | 
						|
    for table in tables:
 | 
						|
        if table not in data:
 | 
						|
            if strict:
 | 
						|
                raise ValueError(Errors.E955.format(table=table, lang=lang))
 | 
						|
            language_data = {}  # type: ignore[var-annotated]
 | 
						|
        else:
 | 
						|
            language_data = load_language_data(data[table])  # type: ignore[assignment]
 | 
						|
        lookups.add_table(table, language_data)
 | 
						|
    return lookups
 | 
						|
 | 
						|
 | 
						|
class Table(OrderedDict):
 | 
						|
    """A table in the lookups. Subclass of builtin dict that implements a
 | 
						|
    slightly more consistent and unified API.
 | 
						|
 | 
						|
    Includes a Bloom filter to speed up missed lookups.
 | 
						|
    """
 | 
						|
 | 
						|
    @classmethod
 | 
						|
    def from_dict(cls, data: dict, name: Optional[str] = None) -> "Table":
 | 
						|
        """Initialize a new table from a dict.
 | 
						|
 | 
						|
        data (dict): The dictionary.
 | 
						|
        name (str): Optional table name for reference.
 | 
						|
 | 
						|
        DOCS: https://spacy.io/api/lookups#table.from_dict
 | 
						|
        """
 | 
						|
        self = cls(name=name)
 | 
						|
        self.update(data)
 | 
						|
        return self
 | 
						|
 | 
						|
    def __init__(self, name: Optional[str] = None, data: Optional[dict] = None) -> None:
 | 
						|
        """Initialize a new table.
 | 
						|
 | 
						|
        name (str): Optional table name for reference.
 | 
						|
        data (dict): Initial data, used to hint Bloom Filter.
 | 
						|
 | 
						|
        DOCS: https://spacy.io/api/lookups#table.init
 | 
						|
        """
 | 
						|
        OrderedDict.__init__(self)
 | 
						|
        self.name = name
 | 
						|
        # Assume a default size of 1M items
 | 
						|
        self.default_size = 1e6
 | 
						|
        size = max(len(data), 1) if data is not None else self.default_size
 | 
						|
        self.bloom = BloomFilter.from_error_rate(size)
 | 
						|
        if data:
 | 
						|
            self.update(data)
 | 
						|
 | 
						|
    def __setitem__(self, key: Union[str, int], value: Any) -> None:
 | 
						|
        """Set new key/value pair. String keys will be hashed.
 | 
						|
 | 
						|
        key (str / int): The key to set.
 | 
						|
        value: The value to set.
 | 
						|
        """
 | 
						|
        key = get_string_id(key)
 | 
						|
        OrderedDict.__setitem__(self, key, value)  # type:ignore[assignment]
 | 
						|
        self.bloom.add(key)
 | 
						|
 | 
						|
    def set(self, key: Union[str, int], value: Any) -> None:
 | 
						|
        """Set new key/value pair. String keys will be hashed.
 | 
						|
        Same as table[key] = value.
 | 
						|
 | 
						|
        key (str / int): The key to set.
 | 
						|
        value: The value to set.
 | 
						|
        """
 | 
						|
        self[key] = value
 | 
						|
 | 
						|
    def __getitem__(self, key: Union[str, int]) -> Any:
 | 
						|
        """Get the value for a given key. String keys will be hashed.
 | 
						|
 | 
						|
        key (str / int): The key to get.
 | 
						|
        RETURNS: The value.
 | 
						|
        """
 | 
						|
        key = get_string_id(key)
 | 
						|
        return OrderedDict.__getitem__(self, key)  # type:ignore[index]
 | 
						|
 | 
						|
    def get(self, key: Union[str, int], default: Optional[Any] = None) -> Any:
 | 
						|
        """Get the value for a given key. String keys will be hashed.
 | 
						|
 | 
						|
        key (str / int): The key to get.
 | 
						|
        default: The default value to return.
 | 
						|
        RETURNS: The value.
 | 
						|
        """
 | 
						|
        key = get_string_id(key)
 | 
						|
        return OrderedDict.get(self, key, default)  # type:ignore[arg-type]
 | 
						|
 | 
						|
    def __contains__(self, key: Union[str, int]) -> bool:  # type: ignore[override]
 | 
						|
        """Check whether a key is in the table. String keys will be hashed.
 | 
						|
 | 
						|
        key (str / int): The key to check.
 | 
						|
        RETURNS (bool): Whether the key is in the table.
 | 
						|
        """
 | 
						|
        key = get_string_id(key)
 | 
						|
        # This can give a false positive, so we need to check it after
 | 
						|
        if key not in self.bloom:
 | 
						|
            return False
 | 
						|
        return OrderedDict.__contains__(self, key)
 | 
						|
 | 
						|
    def to_bytes(self) -> bytes:
 | 
						|
        """Serialize table to a bytestring.
 | 
						|
 | 
						|
        RETURNS (bytes): The serialized table.
 | 
						|
 | 
						|
        DOCS: https://spacy.io/api/lookups#table.to_bytes
 | 
						|
        """
 | 
						|
        data = {
 | 
						|
            "name": self.name,
 | 
						|
            "dict": dict(self.items()),
 | 
						|
            "bloom": self.bloom.to_bytes(),
 | 
						|
        }
 | 
						|
        return srsly.msgpack_dumps(data)
 | 
						|
 | 
						|
    def from_bytes(self, bytes_data: bytes) -> "Table":
 | 
						|
        """Load a table from a bytestring.
 | 
						|
 | 
						|
        bytes_data (bytes): The data to load.
 | 
						|
        RETURNS (Table): The loaded table.
 | 
						|
 | 
						|
        DOCS: https://spacy.io/api/lookups#table.from_bytes
 | 
						|
        """
 | 
						|
        loaded = srsly.msgpack_loads(bytes_data)
 | 
						|
        data = loaded.get("dict", {})
 | 
						|
        self.name = loaded["name"]
 | 
						|
        self.bloom = BloomFilter().from_bytes(loaded["bloom"])
 | 
						|
        self.clear()
 | 
						|
        self.update(data)
 | 
						|
        return self
 | 
						|
 | 
						|
 | 
						|
class Lookups:
 | 
						|
    """Container for large lookup tables and dictionaries, e.g. lemmatization
 | 
						|
    data or tokenizer exception lists. Lookups are available via vocab.lookups,
 | 
						|
    so they can be accessed before the pipeline components are applied (e.g.
 | 
						|
    in the tokenizer and lemmatizer), as well as within the pipeline components
 | 
						|
    via doc.vocab.lookups.
 | 
						|
    """
 | 
						|
 | 
						|
    def __init__(self) -> None:
 | 
						|
        """Initialize the Lookups object.
 | 
						|
 | 
						|
        DOCS: https://spacy.io/api/lookups#init
 | 
						|
        """
 | 
						|
        self._tables: Dict[str, Table] = {}
 | 
						|
 | 
						|
    def __contains__(self, name: str) -> bool:
 | 
						|
        """Check if the lookups contain a table of a given name. Delegates to
 | 
						|
        Lookups.has_table.
 | 
						|
 | 
						|
        name (str): Name of the table.
 | 
						|
        RETURNS (bool): Whether a table of that name is in the lookups.
 | 
						|
        """
 | 
						|
        return self.has_table(name)
 | 
						|
 | 
						|
    def __len__(self) -> int:
 | 
						|
        """RETURNS (int): The number of tables in the lookups."""
 | 
						|
        return len(self._tables)
 | 
						|
 | 
						|
    @property
 | 
						|
    def tables(self) -> List[str]:
 | 
						|
        """RETURNS (List[str]): Names of all tables in the lookups."""
 | 
						|
        return list(self._tables.keys())
 | 
						|
 | 
						|
    def add_table(self, name: str, data: dict = SimpleFrozenDict()) -> Table:
 | 
						|
        """Add a new table to the lookups. Raises an error if the table exists.
 | 
						|
 | 
						|
        name (str): Unique name of table.
 | 
						|
        data (dict): Optional data to add to the table.
 | 
						|
        RETURNS (Table): The newly added table.
 | 
						|
 | 
						|
        DOCS: https://spacy.io/api/lookups#add_table
 | 
						|
        """
 | 
						|
        if name in self.tables:
 | 
						|
            raise ValueError(Errors.E158.format(name=name))
 | 
						|
        table = Table(name=name, data=data)
 | 
						|
        self._tables[name] = table
 | 
						|
        return table
 | 
						|
 | 
						|
    def set_table(self, name: str, table: Table) -> None:
 | 
						|
        """Set a table.
 | 
						|
 | 
						|
        name (str): Name of the table to set.
 | 
						|
        table (Table): The Table to set.
 | 
						|
 | 
						|
        DOCS: https://spacy.io/api/lookups#set_table
 | 
						|
        """
 | 
						|
        self._tables[name] = table
 | 
						|
 | 
						|
    def get_table(self, name: str, default: Any = UNSET) -> Table:
 | 
						|
        """Get a table. Raises an error if the table doesn't exist and no
 | 
						|
        default value is provided.
 | 
						|
 | 
						|
        name (str): Name of the table.
 | 
						|
        default (Any): Optional default value to return if table doesn't exist.
 | 
						|
        RETURNS (Table): The table.
 | 
						|
 | 
						|
        DOCS: https://spacy.io/api/lookups#get_table
 | 
						|
        """
 | 
						|
        if name not in self._tables:
 | 
						|
            if default == UNSET:
 | 
						|
                raise KeyError(Errors.E159.format(name=name, tables=self.tables))
 | 
						|
            return default
 | 
						|
        return self._tables[name]
 | 
						|
 | 
						|
    def remove_table(self, name: str) -> Table:
 | 
						|
        """Remove a table. Raises an error if the table doesn't exist.
 | 
						|
 | 
						|
        name (str): Name of the table to remove.
 | 
						|
        RETURNS (Table): The removed table.
 | 
						|
 | 
						|
        DOCS: https://spacy.io/api/lookups#remove_table
 | 
						|
        """
 | 
						|
        if name not in self._tables:
 | 
						|
            raise KeyError(Errors.E159.format(name=name, tables=self.tables))
 | 
						|
        return self._tables.pop(name)
 | 
						|
 | 
						|
    def has_table(self, name: str) -> bool:
 | 
						|
        """Check if the lookups contain a table of a given name.
 | 
						|
 | 
						|
        name (str): Name of the table.
 | 
						|
        RETURNS (bool): Whether a table of that name exists.
 | 
						|
 | 
						|
        DOCS: https://spacy.io/api/lookups#has_table
 | 
						|
        """
 | 
						|
        return name in self._tables
 | 
						|
 | 
						|
    def to_bytes(self, **kwargs) -> bytes:
 | 
						|
        """Serialize the lookups to a bytestring.
 | 
						|
 | 
						|
        RETURNS (bytes): The serialized Lookups.
 | 
						|
 | 
						|
        DOCS: https://spacy.io/api/lookups#to_bytes
 | 
						|
        """
 | 
						|
        return srsly.msgpack_dumps(self._tables)
 | 
						|
 | 
						|
    def from_bytes(self, bytes_data: bytes, **kwargs) -> "Lookups":
 | 
						|
        """Load the lookups from a bytestring.
 | 
						|
 | 
						|
        bytes_data (bytes): The data to load.
 | 
						|
        RETURNS (Lookups): The loaded Lookups.
 | 
						|
 | 
						|
        DOCS: https://spacy.io/api/lookups#from_bytes
 | 
						|
        """
 | 
						|
        self._tables = {}
 | 
						|
        for key, value in srsly.msgpack_loads(bytes_data).items():
 | 
						|
            self._tables[key] = Table(key, value)
 | 
						|
        return self
 | 
						|
 | 
						|
    def to_disk(
 | 
						|
        self, path: Union[str, Path], filename: str = "lookups.bin", **kwargs
 | 
						|
    ) -> None:
 | 
						|
        """Save the lookups to a directory as lookups.bin. Expects a path to a
 | 
						|
        directory, which will be created if it doesn't exist.
 | 
						|
 | 
						|
        path (str / Path): The file path.
 | 
						|
 | 
						|
        DOCS: https://spacy.io/api/lookups#to_disk
 | 
						|
        """
 | 
						|
        path = ensure_path(path)
 | 
						|
        if not path.exists():
 | 
						|
            path.mkdir()
 | 
						|
        filepath = path / filename
 | 
						|
        with filepath.open("wb") as file_:
 | 
						|
            file_.write(self.to_bytes())
 | 
						|
 | 
						|
    def from_disk(
 | 
						|
        self, path: Union[str, Path], filename: str = "lookups.bin", **kwargs
 | 
						|
    ) -> "Lookups":
 | 
						|
        """Load lookups from a directory containing a lookups.bin. Will skip
 | 
						|
        loading if the file doesn't exist.
 | 
						|
 | 
						|
        path (str / Path): The directory path.
 | 
						|
        RETURNS (Lookups): The loaded lookups.
 | 
						|
 | 
						|
        DOCS: https://spacy.io/api/lookups#from_disk
 | 
						|
        """
 | 
						|
        path = ensure_path(path)
 | 
						|
        filepath = path / filename
 | 
						|
        if filepath.exists():
 | 
						|
            with filepath.open("rb") as file_:
 | 
						|
                data = file_.read()
 | 
						|
            return self.from_bytes(data)
 | 
						|
        return self
 |