Mirror of https://github.com/explosion/spaCy.git

Commit 9c8b2524fe: Upd initialize args
@@ -5,11 +5,35 @@ from wasabi import msg
 import typer
 
 from .. import util
-from ..training.initialize import init_nlp
+from ..training.initialize import init_nlp, convert_vectors
 from ._util import init_cli, Arg, Opt, parse_config_overrides, show_validation_error
 from ._util import import_code, setup_gpu
 
 
+@init_cli.command("vectors")
+def init_vectors_cli(
+    # fmt: off
+    lang: str = Arg(..., help="The language of the nlp object to create"),
+    vectors_loc: Path = Arg(..., help="Vectors file in Word2Vec format", exists=True),
+    output_dir: Path = Arg(..., help="Pipeline output directory"),
+    prune: int = Opt(-1, "--prune", "-p", help="Optional number of vectors to prune to"),
+    truncate: int = Opt(0, "--truncate", "-t", help="Optional number of vectors to truncate to when reading in vectors file"),
+    name: Optional[str] = Opt(None, "--name", "-n", help="Optional name for the word vectors, e.g. en_core_web_lg.vectors"),
+    # fmt: on
+):
+    msg.info(f"Creating blank nlp object for language '{lang}'")
+    nlp = util.get_lang_class(lang)()
+    convert_vectors(
+        nlp, vectors_loc, truncate=truncate, prune=prune, name=name, silent=False
+    )
+    nlp.to_disk(output_dir)
+    msg.good(
+        "Saved nlp object with vectors to output directory. You can now use the "
+        "path to it in your config as the 'vectors' setting in [initialize.vocab].",
+        output_dir,
+    )
+
+
 @init_cli.command(
     "nlp",
     context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
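The new command is invoked as `python -m spacy init vectors lang vectors_loc output_dir`. Roughly, it wires together a blank pipeline, the new `convert_vectors` helper (added further below) and `to_disk`. A minimal Python sketch of the same flow, with placeholder paths, where `spacy.blank` stands in for `util.get_lang_class(lang)()`:

    import spacy
    from spacy.training.initialize import convert_vectors

    # Placeholder paths: any Word2Vec-format text file, optionally compressed
    # as .gz, .zip or .tar.gz (see open_file below)
    nlp = spacy.blank("en")
    convert_vectors(nlp, "vectors.txt", truncate=0, prune=-1, silent=False)
    nlp.to_disk("./vectors_pipeline")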
@@ -27,7 +27,7 @@ from .lang.punctuation import TOKENIZER_INFIXES
 from .tokens import Doc
 from .tokenizer import Tokenizer
 from .errors import Errors, Warnings
-from .schemas import ConfigSchema, ConfigSchemaNlp
+from .schemas import ConfigSchema, ConfigSchemaNlp, validate_init_settings
 from .git_info import GIT_VERSION
 from . import util
 from . import about
@@ -1161,8 +1161,10 @@ class Language:
     def initialize(
         self,
         get_examples: Optional[Callable[[], Iterable[Example]]] = None,
-        sgd: Optional[Optimizer]=None
-    ) -> None:
+        *,
+        settings: Dict[str, Dict[str, Any]] = SimpleFrozenDict(),
+        sgd: Optional[Optimizer] = None,
+    ) -> Optimizer:
         """Initialize the pipe for training, using data examples if available.
 
         get_examples (Callable[[], Iterable[Example]]): Optional function that
@@ -1200,10 +1202,28 @@ class Language:
         if self.vocab.vectors.data.shape[1] >= 1:
             ops = get_current_ops()
             self.vocab.vectors.data = ops.asarray(self.vocab.vectors.data)
+        self._optimizer = sgd
+        if hasattr(self.tokenizer, "initialize"):
+            tok_settings = settings.get("tokenizer", {})
+            tok_settings = validate_init_settings(
+                self.tokenizer.initialize,
+                tok_settings,
+                section="tokenizer",
+                name="tokenizer",
+            )
+            self.tokenizer.initialize(get_examples, nlp=self, **tok_settings)
+        proc_settings = settings.get("components", {})
         for name, proc in self.pipeline:
             if hasattr(proc, "initialize"):
+                p_settings = proc_settings.get(name, {})
+                p_settings = validate_init_settings(
+                    proc.initialize, p_settings, section="components", name=name
+                )
                 proc.initialize(
-                    get_examples, pipeline=self.pipeline
+                    get_examples,
+                    pipeline=self.pipeline,
+                    **p_settings,
                 )
         self._link_components()
         if sgd is not None:
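With the new keyword-only `settings` argument, per-tokenizer and per-component initialization arguments are validated against the matching `initialize` signature before the call. A minimal sketch of a caller (assumes an `nlp` object and a list of `Example` objects named `examples`; the "ner" key is illustrative):

    optimizer = nlp.initialize(
        get_examples=lambda: examples,
        settings={
            "tokenizer": {},            # validated against self.tokenizer.initialize
            "components": {"ner": {}},  # validated against each component's initialize
        },
    )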
@@ -1,4 +1,4 @@
-# cython: infer_types=True, cdivision=True, boundscheck=False
+# cython: infer_types=True, cdivision=True, boundscheck=False, binding=True
 from __future__ import print_function
 from cymem.cymem cimport Pool
 cimport numpy as np
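Context for this one-line change: compiled Cython methods only expose a signature to `inspect.signature` when the module is built with `binding=True`. The new `get_arg_model` helper (below) relies on this; without the directive, inspection raises `ValueError` and validation falls back to a schema that allows everything.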
@@ -1,11 +1,13 @@
 from typing import Dict, List, Union, Optional, Any, Callable, Type, Tuple
 from typing import Iterable, TypeVar, TYPE_CHECKING
 from enum import Enum
-from pydantic import BaseModel, Field, ValidationError, validator
+from pydantic import BaseModel, Field, ValidationError, validator, create_model
 from pydantic import StrictStr, StrictInt, StrictFloat, StrictBool
+from pydantic.main import ModelMetaclass
+from thinc.api import Optimizer, ConfigValidationError
 from thinc.config import Promise
 from collections import defaultdict
-from thinc.api import Optimizer
+import inspect
 
 from .attrs import NAMES
 from .lookups import Lookups
@@ -43,6 +45,93 @@ def validate(schema: Type[BaseModel], obj: Dict[str, Any]) -> List[str]:
         return [f"[{loc}] {', '.join(msg)}" for loc, msg in data.items()]
 
 
+# Initialization
+
+
+class ArgSchemaConfig:
+    extra = "forbid"
+    arbitrary_types_allowed = True
+
+
+class ArgSchemaConfigExtra:
+    extra = "allow"
+    arbitrary_types_allowed = True
+
+
+def get_arg_model(
+    func: Callable,
+    *,
+    exclude: Iterable[str] = tuple(),
+    name: str = "ArgModel",
+    strict: bool = True,
+) -> ModelMetaclass:
+    """Generate a pydantic model for function arguments.
+
+    func (Callable): The function to generate the schema for.
+    exclude (Iterable[str]): Parameter names to ignore.
+    name (str): Name of created model class.
+    strict (bool): Don't allow extra arguments if no variable keyword arguments
+        are allowed on the function.
+    RETURNS (ModelMetaclass): A pydantic model.
+    """
+    sig_args = {}
+    try:
+        sig = inspect.signature(func)
+    except ValueError:
+        # Typically happens if the method is part of a Cython module without
+        # binding=True. Here we just use an empty model that allows everything.
+        return create_model(name, __config__=ArgSchemaConfigExtra)
+    has_variable = False
+    for param in sig.parameters.values():
+        if param.name in exclude:
+            continue
+        if param.kind == param.VAR_KEYWORD:
+            # The function allows variable keyword arguments, so we shouldn't
+            # include **kwargs etc. in the schema; instead, switch to
+            # non-strict mode and pass through all other values
+            has_variable = True
+            continue
+        # If no annotation is specified assume it's anything
+        annotation = param.annotation if param.annotation != param.empty else Any
+        # If no default value is specified assume that it's required
+        default = param.default if param.default != param.empty else ...
+        sig_args[param.name] = (annotation, default)
+    is_strict = strict and not has_variable
+    sig_args["__config__"] = ArgSchemaConfig if is_strict else ArgSchemaConfigExtra
+    return create_model(name, **sig_args)
+
+
+def validate_init_settings(
+    func: Callable,
+    settings: Dict[str, Any],
+    *,
+    section: Optional[str] = None,
+    name: str = "",
+    exclude: Iterable[str] = ("get_examples", "nlp", "pipeline", "sgd"),
+) -> Dict[str, Any]:
+    """Validate initialization settings against the expected arguments in
+    the method signature. Will parse values if possible (e.g. int to string)
+    and return the updated settings dict. Will raise a ConfigValidationError
+    if types don't match or required values are missing.
+
+    func (Callable): The initialize method of a given component etc.
+    settings (Dict[str, Any]): The settings from the respective [initialize] block.
+    section (str): Initialize section, for error message.
+    name (str): Name of the block in the section.
+    exclude (Iterable[str]): Parameter names to exclude from schema.
+    RETURNS (Dict[str, Any]): The validated settings.
+    """
+    schema = get_arg_model(func, exclude=exclude, name="InitArgModel")
+    try:
+        return schema(**settings).dict()
+    except ValidationError as e:
+        block = "initialize" if not section else f"initialize.{section}"
+        title = f"Error validating initialization settings in [{block}]"
+        raise ConfigValidationError(
+            title=title, errors=e.errors(), config=settings, parent=name,
+        ) from None
+
+
 # Matcher token patterns
 
 
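To illustrate the signature-to-schema technique, here is a small self-contained sketch of what `get_arg_model` builds with pydantic's `create_model`; the `initialize` function here is hypothetical:

    from typing import Any
    import inspect
    from pydantic import create_model

    def initialize(data_path: str, max_items: int = 100):  # hypothetical initialize method
        ...

    fields = {}
    for param in inspect.signature(initialize).parameters.values():
        # Same rules as get_arg_model: a missing annotation means Any,
        # a missing default means the field is required (...)
        annotation = param.annotation if param.annotation != param.empty else Any
        default = param.default if param.default != param.empty else ...
        fields[param.name] = (annotation, default)

    Model = create_model("InitArgModel", **fields)
    print(Model(data_path="./data", max_items="50").dict())
    # {'data_path': './data', 'max_items': 50}: pydantic parsed the string to an int

`validate_init_settings` wraps exactly this pattern and converts any `ValidationError` into a `ConfigValidationError` that names the offending `[initialize]` block.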
@@ -1,13 +1,19 @@
-from typing import Union, Dict, Optional, Any, List
+from typing import Union, Dict, Optional, Any, List, IO
 from thinc.api import Config, fix_random_seed, set_gpu_allocator
 from thinc.api import ConfigValidationError
 from pathlib import Path
 from wasabi import Printer
 import srsly
 import numpy
+import tarfile
+import gzip
+import zipfile
+import tqdm
 
 from .loop import create_before_to_disk_callback
 from ..language import Language
 from ..lookups import Lookups
+from ..vectors import Vectors
+from ..errors import Errors
 from ..schemas import ConfigSchemaTraining, ConfigSchemaInit, ConfigSchemaPretrain
 from ..util import registry, load_model_from_config, resolve_dot_names
@@ -49,8 +55,8 @@ def init_nlp(config: Config, *, use_gpu: int = -1, silent: bool = True) -> Langu
             msg.info(f"Resuming training for: {resume_components}")
             nlp.resume_training(sgd=optimizer)
     with nlp.select_pipes(disable=[*frozen_components, *resume_components]):
-        nlp.initialize(lambda: train_corpus(nlp), sgd=optimizer)
-        msg.good(f"Initialized pipeline components")
+        nlp.initialize(lambda: train_corpus(nlp), sgd=optimizer, settings=I)
+        msg.good("Initialized pipeline components")
     # Verify the config after calling 'initialize' to ensure labels
     # are properly initialized
     verify_config(nlp)
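(The `I` passed as `settings=I` is defined earlier in `init_nlp`, outside this hunk; presumably the resolved `[initialize]` section of the config.)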
@@ -103,7 +109,7 @@ def init_vocab(
 
 
 def load_vectors_into_model(
-    nlp: "Language", name: Union[str, Path], *, add_strings: bool = True
+    nlp: Language, name: Union[str, Path], *, add_strings: bool = True
 ) -> None:
     """Load word vectors from an installed model or path into a model instance."""
     try:
@@ -202,3 +208,104 @@ def get_sourced_components(config: Union[Dict[str, Any], Config]) -> List[str]:
         for name, cfg in config.get("components", {}).items()
         if "factory" not in cfg and "source" in cfg
     ]
+
+
+def convert_vectors(
+    nlp: Language,
+    vectors_loc: Optional[Path],
+    *,
+    truncate: int,
+    prune: int,
+    name: Optional[str] = None,
+    silent: bool = True,
+) -> None:
+    msg = Printer(no_print=silent)
+    vectors_loc = ensure_path(vectors_loc)
+    if vectors_loc and vectors_loc.parts[-1].endswith(".npz"):
+        nlp.vocab.vectors = Vectors(data=numpy.load(vectors_loc.open("rb")))
+        for lex in nlp.vocab:
+            if lex.rank and lex.rank != OOV_RANK:
+                nlp.vocab.vectors.add(lex.orth, row=lex.rank)
+    else:
+        if vectors_loc:
+            with msg.loading(f"Reading vectors from {vectors_loc}"):
+                vectors_data, vector_keys = read_vectors(vectors_loc, truncate)
+            msg.good(f"Loaded vectors from {vectors_loc}")
+        else:
+            vectors_data, vector_keys = (None, None)
+        if vector_keys is not None:
+            for word in vector_keys:
+                if word not in nlp.vocab:
+                    nlp.vocab[word]
+        if vectors_data is not None:
+            nlp.vocab.vectors = Vectors(data=vectors_data, keys=vector_keys)
+    if name is None:
+        # TODO: Is this correct? Does this matter?
+        nlp.vocab.vectors.name = f"{nlp.meta['lang']}_{nlp.meta['name']}.vectors"
+    else:
+        nlp.vocab.vectors.name = name
+    nlp.meta["vectors"]["name"] = nlp.vocab.vectors.name
+    if prune >= 1:
+        nlp.vocab.prune_vectors(prune)
+    msg.good(f"Successfully converted {len(nlp.vocab.vectors)} vectors")
+
+
+def read_vectors(vectors_loc: Path, truncate_vectors: int):
+    f = open_file(vectors_loc)
+    f = ensure_shape(f)
+    shape = tuple(int(size) for size in next(f).split())
+    if truncate_vectors >= 1:
+        shape = (truncate_vectors, shape[1])
+    vectors_data = numpy.zeros(shape=shape, dtype="f")
+    vectors_keys = []
+    for i, line in enumerate(tqdm.tqdm(f)):
+        line = line.rstrip()
+        pieces = line.rsplit(" ", vectors_data.shape[1])
+        word = pieces.pop(0)
+        if len(pieces) != vectors_data.shape[1]:
+            raise ValueError(Errors.E094.format(line_num=i, loc=vectors_loc))
+        vectors_data[i] = numpy.asarray(pieces, dtype="f")
+        vectors_keys.append(word)
+        if i == truncate_vectors - 1:
+            break
+    return vectors_data, vectors_keys
+
+
+def open_file(loc: Union[str, Path]) -> IO:
+    """Handle .gz, .tar.gz or unzipped files"""
+    loc = ensure_path(loc)
+    if tarfile.is_tarfile(str(loc)):
+        return tarfile.open(str(loc), "r:gz")
+    elif loc.parts[-1].endswith("gz"):
+        return (line.decode("utf8") for line in gzip.open(str(loc), "r"))
+    elif loc.parts[-1].endswith("zip"):
+        zip_file = zipfile.ZipFile(str(loc))
+        names = zip_file.namelist()
+        file_ = zip_file.open(names[0])
+        return (line.decode("utf8") for line in file_)
+    else:
+        return loc.open("r", encoding="utf8")
+
+
+def ensure_shape(lines):
+    """Ensure that the first line of the data is the vectors shape.
+    If it's not, we read in the data and output the shape as the first result,
+    so that the reader doesn't have to deal with the problem.
+    """
+    first_line = next(lines)
+    try:
+        shape = tuple(int(size) for size in first_line.split())
+    except ValueError:
+        shape = None
+    if shape is not None:
+        # All good, give the data
+        yield first_line
+        yield from lines
+    else:
+        # Figure out the shape, make it the first value, and then give the
+        # rest of the data.
+        width = len(first_line.split()) - 1
+        captured = [first_line] + list(lines)
+        length = len(captured)
+        yield f"{length} {width}"
+        yield from captured
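A quick illustration of the header-recovery behavior, assuming `ensure_shape` from the hunk above is in scope:

    lines = iter(["apple 0.1 0.2 0.3", "banana 0.4 0.5 0.6"])  # no shape header line
    fixed = ensure_shape(lines)
    print(next(fixed))  # "2 3": the inferred shape line that read_vectors expects
    print(next(fixed))  # "apple 0.1 0.2 0.3"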