mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-13 18:56:36 +03:00
Upd initialize args
This commit is contained in:
commit
9c8b2524fe
|
@ -5,11 +5,35 @@ from wasabi import msg
|
||||||
import typer
|
import typer
|
||||||
|
|
||||||
from .. import util
|
from .. import util
|
||||||
from ..training.initialize import init_nlp
|
from ..training.initialize import init_nlp, convert_vectors
|
||||||
from ._util import init_cli, Arg, Opt, parse_config_overrides, show_validation_error
|
from ._util import init_cli, Arg, Opt, parse_config_overrides, show_validation_error
|
||||||
from ._util import import_code, setup_gpu
|
from ._util import import_code, setup_gpu
|
||||||
|
|
||||||
|
|
||||||
|
@init_cli.command("vectors")
|
||||||
|
def init_vectors_cli(
|
||||||
|
# fmt: off
|
||||||
|
lang: str = Arg(..., help="The language of the nlp object to create"),
|
||||||
|
vectors_loc: Path = Arg(..., help="Vectors file in Word2Vec format", exists=True),
|
||||||
|
output_dir: Path = Arg(..., help="Pipeline output directory"),
|
||||||
|
prune: int = Opt(-1, "--prune", "-p", help="Optional number of vectors to prune to"),
|
||||||
|
truncate: int = Opt(0, "--truncate", "-t", help="Optional number of vectors to truncate to when reading in vectors file"),
|
||||||
|
name: Optional[str] = Opt(None, "--name", "-n", help="Optional name for the word vectors, e.g. en_core_web_lg.vectors"),
|
||||||
|
# fmt: on
|
||||||
|
):
|
||||||
|
msg.info(f"Creating blank nlp object for language '{lang}'")
|
||||||
|
nlp = util.get_lang_class(lang)()
|
||||||
|
convert_vectors(
|
||||||
|
nlp, vectors_loc, truncate=truncate, prune=prune, name=name, silent=False
|
||||||
|
)
|
||||||
|
nlp.to_disk(output_dir)
|
||||||
|
msg.good(
|
||||||
|
"Saved nlp object with vectors to output directory. You can now use the "
|
||||||
|
"path to it in your config as the 'vectors' setting in [initialize.vocab].",
|
||||||
|
output_dir,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@init_cli.command(
|
@init_cli.command(
|
||||||
"nlp",
|
"nlp",
|
||||||
context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
|
context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
|
||||||
|
|
|
@ -27,7 +27,7 @@ from .lang.punctuation import TOKENIZER_INFIXES
|
||||||
from .tokens import Doc
|
from .tokens import Doc
|
||||||
from .tokenizer import Tokenizer
|
from .tokenizer import Tokenizer
|
||||||
from .errors import Errors, Warnings
|
from .errors import Errors, Warnings
|
||||||
from .schemas import ConfigSchema, ConfigSchemaNlp
|
from .schemas import ConfigSchema, ConfigSchemaNlp, validate_init_settings
|
||||||
from .git_info import GIT_VERSION
|
from .git_info import GIT_VERSION
|
||||||
from . import util
|
from . import util
|
||||||
from . import about
|
from . import about
|
||||||
|
@ -1161,8 +1161,10 @@ class Language:
|
||||||
def initialize(
|
def initialize(
|
||||||
self,
|
self,
|
||||||
get_examples: Optional[Callable[[], Iterable[Example]]] = None,
|
get_examples: Optional[Callable[[], Iterable[Example]]] = None,
|
||||||
sgd: Optional[Optimizer]=None
|
*,
|
||||||
) -> None:
|
settings: Dict[str, Dict[str, Any]] = SimpleFrozenDict(),
|
||||||
|
sgd: Optional[Optimizer] = None,
|
||||||
|
) -> Optimizer:
|
||||||
"""Initialize the pipe for training, using data examples if available.
|
"""Initialize the pipe for training, using data examples if available.
|
||||||
|
|
||||||
get_examples (Callable[[], Iterable[Example]]): Optional function that
|
get_examples (Callable[[], Iterable[Example]]): Optional function that
|
||||||
|
@ -1200,10 +1202,28 @@ class Language:
|
||||||
if self.vocab.vectors.data.shape[1] >= 1:
|
if self.vocab.vectors.data.shape[1] >= 1:
|
||||||
ops = get_current_ops()
|
ops = get_current_ops()
|
||||||
self.vocab.vectors.data = ops.asarray(self.vocab.vectors.data)
|
self.vocab.vectors.data = ops.asarray(self.vocab.vectors.data)
|
||||||
|
self._optimizer = sgd
|
||||||
|
if hasattr(self.tokenizer, "initialize"):
|
||||||
|
tok_settings = settings.get("tokenizer", {})
|
||||||
|
tok_settings = validate_init_settings(
|
||||||
|
self.tokenizer.initialize,
|
||||||
|
tok_settings,
|
||||||
|
section="tokenizer",
|
||||||
|
name="tokenizer",
|
||||||
|
)
|
||||||
|
self.tokenizer.initialize(get_examples, nlp=self, **tok_settings)
|
||||||
|
proc_settings = settings.get("components", {})
|
||||||
for name, proc in self.pipeline:
|
for name, proc in self.pipeline:
|
||||||
if hasattr(proc, "initialize"):
|
if hasattr(proc, "initialize"):
|
||||||
|
p_settings = proc_settings.get(name, {})
|
||||||
|
p_settings = validate_init_settings(
|
||||||
|
proc.initialize, p_settings, section="components", name=name
|
||||||
|
)
|
||||||
proc.initialize(
|
proc.initialize(
|
||||||
get_examples, pipeline=self.pipeline
|
get_examples, pipeline=self.pipeline
|
||||||
|
get_examples,
|
||||||
|
pipeline=self.pipeline,
|
||||||
|
**p_settings,
|
||||||
)
|
)
|
||||||
self._link_components()
|
self._link_components()
|
||||||
if sgd is not None:
|
if sgd is not None:
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
# cython: infer_types=True, cdivision=True, boundscheck=False
|
# cython: infer_types=True, cdivision=True, boundscheck=False, binding=True
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
from cymem.cymem cimport Pool
|
from cymem.cymem cimport Pool
|
||||||
cimport numpy as np
|
cimport numpy as np
|
||||||
|
|
|
@ -1,11 +1,13 @@
|
||||||
from typing import Dict, List, Union, Optional, Any, Callable, Type, Tuple
|
from typing import Dict, List, Union, Optional, Any, Callable, Type, Tuple
|
||||||
from typing import Iterable, TypeVar, TYPE_CHECKING
|
from typing import Iterable, TypeVar, TYPE_CHECKING
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from pydantic import BaseModel, Field, ValidationError, validator
|
from pydantic import BaseModel, Field, ValidationError, validator, create_model
|
||||||
from pydantic import StrictStr, StrictInt, StrictFloat, StrictBool
|
from pydantic import StrictStr, StrictInt, StrictFloat, StrictBool
|
||||||
|
from pydantic.main import ModelMetaclass
|
||||||
|
from thinc.api import Optimizer, ConfigValidationError
|
||||||
from thinc.config import Promise
|
from thinc.config import Promise
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from thinc.api import Optimizer
|
import inspect
|
||||||
|
|
||||||
from .attrs import NAMES
|
from .attrs import NAMES
|
||||||
from .lookups import Lookups
|
from .lookups import Lookups
|
||||||
|
@ -43,6 +45,93 @@ def validate(schema: Type[BaseModel], obj: Dict[str, Any]) -> List[str]:
|
||||||
return [f"[{loc}] {', '.join(msg)}" for loc, msg in data.items()]
|
return [f"[{loc}] {', '.join(msg)}" for loc, msg in data.items()]
|
||||||
|
|
||||||
|
|
||||||
|
# Initialization
|
||||||
|
|
||||||
|
|
||||||
|
class ArgSchemaConfig:
|
||||||
|
extra = "forbid"
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
|
||||||
|
class ArgSchemaConfigExtra:
|
||||||
|
extra = "forbid"
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
|
||||||
|
def get_arg_model(
|
||||||
|
func: Callable,
|
||||||
|
*,
|
||||||
|
exclude: Iterable[str] = tuple(),
|
||||||
|
name: str = "ArgModel",
|
||||||
|
strict: bool = True,
|
||||||
|
) -> ModelMetaclass:
|
||||||
|
"""Generate a pydantic model for function arguments.
|
||||||
|
|
||||||
|
func (Callable): The function to generate the schema for.
|
||||||
|
exclude (Iterable[str]): Parameter names to ignore.
|
||||||
|
name (str): Name of created model class.
|
||||||
|
strict (bool): Don't allow extra arguments if no variable keyword arguments
|
||||||
|
are allowed on the function.
|
||||||
|
RETURNS (ModelMetaclass): A pydantic model.
|
||||||
|
"""
|
||||||
|
sig_args = {}
|
||||||
|
try:
|
||||||
|
sig = inspect.signature(func)
|
||||||
|
except ValueError:
|
||||||
|
# Typically happens if the method is part of a Cython module without
|
||||||
|
# binding=True. Here we just use an empty model that allows everything.
|
||||||
|
return create_model(name, __config__=ArgSchemaConfigExtra)
|
||||||
|
has_variable = False
|
||||||
|
for param in sig.parameters.values():
|
||||||
|
if param.name in exclude:
|
||||||
|
continue
|
||||||
|
if param.kind == param.VAR_KEYWORD:
|
||||||
|
# The function allows variable keyword arguments so we shouldn't
|
||||||
|
# include **kwargs etc. in the schema and switch to non-strict
|
||||||
|
# mode and pass through all other values
|
||||||
|
has_variable = True
|
||||||
|
continue
|
||||||
|
# If no annotation is specified assume it's anything
|
||||||
|
annotation = param.annotation if param.annotation != param.empty else Any
|
||||||
|
# If no default value is specified assume that it's required
|
||||||
|
default = param.default if param.default != param.empty else ...
|
||||||
|
sig_args[param.name] = (annotation, default)
|
||||||
|
is_strict = strict and not has_variable
|
||||||
|
sig_args["__config__"] = ArgSchemaConfig if is_strict else ArgSchemaConfigExtra
|
||||||
|
return create_model(name, **sig_args)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_init_settings(
|
||||||
|
func: Callable,
|
||||||
|
settings: Dict[str, Any],
|
||||||
|
*,
|
||||||
|
section: Optional[str] = None,
|
||||||
|
name: str = "",
|
||||||
|
exclude: Iterable[str] = ("get_examples", "nlp", "pipeline", "sgd"),
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Validate initialization settings against the expected arguments in
|
||||||
|
the method signature. Will parse values if possible (e.g. int to string)
|
||||||
|
and return the updated settings dict. Will raise a ConfigValidationError
|
||||||
|
if types don't match or required values are missing.
|
||||||
|
|
||||||
|
func (Callable): The initialize method of a given component etc.
|
||||||
|
settings (Dict[str, Any]): The settings from the repsective [initialize] block.
|
||||||
|
section (str): Initialize section, for error message.
|
||||||
|
name (str): Name of the block in the section.
|
||||||
|
exclude (Iterable[str]): Parameter names to exclude from schema.
|
||||||
|
RETURNS (Dict[str, Any]): The validated settings.
|
||||||
|
"""
|
||||||
|
schema = get_arg_model(func, exclude=exclude, name="InitArgModel")
|
||||||
|
try:
|
||||||
|
return schema(**settings).dict()
|
||||||
|
except ValidationError as e:
|
||||||
|
block = "initialize" if not section else f"initialize.{section}"
|
||||||
|
title = f"Error validating initialization settings in [{block}]"
|
||||||
|
raise ConfigValidationError(
|
||||||
|
title=title, errors=e.errors(), config=settings, parent=name,
|
||||||
|
) from None
|
||||||
|
|
||||||
|
|
||||||
# Matcher token patterns
|
# Matcher token patterns
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,13 +1,19 @@
|
||||||
from typing import Union, Dict, Optional, Any, List
|
from typing import Union, Dict, Optional, Any, List, IO
|
||||||
from thinc.api import Config, fix_random_seed, set_gpu_allocator
|
from thinc.api import Config, fix_random_seed, set_gpu_allocator
|
||||||
from thinc.api import ConfigValidationError
|
from thinc.api import ConfigValidationError
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from wasabi import Printer
|
from wasabi import Printer
|
||||||
import srsly
|
import srsly
|
||||||
|
import numpy
|
||||||
|
import tarfile
|
||||||
|
import gzip
|
||||||
|
import zipfile
|
||||||
|
import tqdm
|
||||||
|
|
||||||
from .loop import create_before_to_disk_callback
|
from .loop import create_before_to_disk_callback
|
||||||
from ..language import Language
|
from ..language import Language
|
||||||
from ..lookups import Lookups
|
from ..lookups import Lookups
|
||||||
|
from ..vectors import Vectors
|
||||||
from ..errors import Errors
|
from ..errors import Errors
|
||||||
from ..schemas import ConfigSchemaTraining, ConfigSchemaInit, ConfigSchemaPretrain
|
from ..schemas import ConfigSchemaTraining, ConfigSchemaInit, ConfigSchemaPretrain
|
||||||
from ..util import registry, load_model_from_config, resolve_dot_names
|
from ..util import registry, load_model_from_config, resolve_dot_names
|
||||||
|
@ -49,8 +55,8 @@ def init_nlp(config: Config, *, use_gpu: int = -1, silent: bool = True) -> Langu
|
||||||
msg.info(f"Resuming training for: {resume_components}")
|
msg.info(f"Resuming training for: {resume_components}")
|
||||||
nlp.resume_training(sgd=optimizer)
|
nlp.resume_training(sgd=optimizer)
|
||||||
with nlp.select_pipes(disable=[*frozen_components, *resume_components]):
|
with nlp.select_pipes(disable=[*frozen_components, *resume_components]):
|
||||||
nlp.initialize(lambda: train_corpus(nlp), sgd=optimizer)
|
nlp.initialize(lambda: train_corpus(nlp), sgd=optimizer, settings=I)
|
||||||
msg.good(f"Initialized pipeline components")
|
msg.good("Initialized pipeline components")
|
||||||
# Verify the config after calling 'initialize' to ensure labels
|
# Verify the config after calling 'initialize' to ensure labels
|
||||||
# are properly initialized
|
# are properly initialized
|
||||||
verify_config(nlp)
|
verify_config(nlp)
|
||||||
|
@ -103,7 +109,7 @@ def init_vocab(
|
||||||
|
|
||||||
|
|
||||||
def load_vectors_into_model(
|
def load_vectors_into_model(
|
||||||
nlp: "Language", name: Union[str, Path], *, add_strings: bool = True
|
nlp: Language, name: Union[str, Path], *, add_strings: bool = True
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Load word vectors from an installed model or path into a model instance."""
|
"""Load word vectors from an installed model or path into a model instance."""
|
||||||
try:
|
try:
|
||||||
|
@ -202,3 +208,104 @@ def get_sourced_components(config: Union[Dict[str, Any], Config]) -> List[str]:
|
||||||
for name, cfg in config.get("components", {}).items()
|
for name, cfg in config.get("components", {}).items()
|
||||||
if "factory" not in cfg and "source" in cfg
|
if "factory" not in cfg and "source" in cfg
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def convert_vectors(
|
||||||
|
nlp: Language,
|
||||||
|
vectors_loc: Optional[Path],
|
||||||
|
*,
|
||||||
|
truncate: int,
|
||||||
|
prune: int,
|
||||||
|
name: Optional[str] = None,
|
||||||
|
silent: bool = True,
|
||||||
|
) -> None:
|
||||||
|
msg = Printer(no_print=silent)
|
||||||
|
vectors_loc = ensure_path(vectors_loc)
|
||||||
|
if vectors_loc and vectors_loc.parts[-1].endswith(".npz"):
|
||||||
|
nlp.vocab.vectors = Vectors(data=numpy.load(vectors_loc.open("rb")))
|
||||||
|
for lex in nlp.vocab:
|
||||||
|
if lex.rank and lex.rank != OOV_RANK:
|
||||||
|
nlp.vocab.vectors.add(lex.orth, row=lex.rank)
|
||||||
|
else:
|
||||||
|
if vectors_loc:
|
||||||
|
with msg.loading(f"Reading vectors from {vectors_loc}"):
|
||||||
|
vectors_data, vector_keys = read_vectors(vectors_loc, truncate)
|
||||||
|
msg.good(f"Loaded vectors from {vectors_loc}")
|
||||||
|
else:
|
||||||
|
vectors_data, vector_keys = (None, None)
|
||||||
|
if vector_keys is not None:
|
||||||
|
for word in vector_keys:
|
||||||
|
if word not in nlp.vocab:
|
||||||
|
nlp.vocab[word]
|
||||||
|
if vectors_data is not None:
|
||||||
|
nlp.vocab.vectors = Vectors(data=vectors_data, keys=vector_keys)
|
||||||
|
if name is None:
|
||||||
|
# TODO: Is this correct? Does this matter?
|
||||||
|
nlp.vocab.vectors.name = f"{nlp.meta['lang']}_{nlp.meta['name']}.vectors"
|
||||||
|
else:
|
||||||
|
nlp.vocab.vectors.name = name
|
||||||
|
nlp.meta["vectors"]["name"] = nlp.vocab.vectors.name
|
||||||
|
if prune >= 1:
|
||||||
|
nlp.vocab.prune_vectors(prune)
|
||||||
|
msg.good(f"Successfully converted {len(nlp.vocab.vectors)} vectors")
|
||||||
|
|
||||||
|
|
||||||
|
def read_vectors(vectors_loc: Path, truncate_vectors: int):
|
||||||
|
f = open_file(vectors_loc)
|
||||||
|
f = ensure_shape(f)
|
||||||
|
shape = tuple(int(size) for size in next(f).split())
|
||||||
|
if truncate_vectors >= 1:
|
||||||
|
shape = (truncate_vectors, shape[1])
|
||||||
|
vectors_data = numpy.zeros(shape=shape, dtype="f")
|
||||||
|
vectors_keys = []
|
||||||
|
for i, line in enumerate(tqdm.tqdm(f)):
|
||||||
|
line = line.rstrip()
|
||||||
|
pieces = line.rsplit(" ", vectors_data.shape[1])
|
||||||
|
word = pieces.pop(0)
|
||||||
|
if len(pieces) != vectors_data.shape[1]:
|
||||||
|
raise ValueError(Errors.E094.format(line_num=i, loc=vectors_loc))
|
||||||
|
vectors_data[i] = numpy.asarray(pieces, dtype="f")
|
||||||
|
vectors_keys.append(word)
|
||||||
|
if i == truncate_vectors - 1:
|
||||||
|
break
|
||||||
|
return vectors_data, vectors_keys
|
||||||
|
|
||||||
|
|
||||||
|
def open_file(loc: Union[str, Path]) -> IO:
|
||||||
|
"""Handle .gz, .tar.gz or unzipped files"""
|
||||||
|
loc = ensure_path(loc)
|
||||||
|
if tarfile.is_tarfile(str(loc)):
|
||||||
|
return tarfile.open(str(loc), "r:gz")
|
||||||
|
elif loc.parts[-1].endswith("gz"):
|
||||||
|
return (line.decode("utf8") for line in gzip.open(str(loc), "r"))
|
||||||
|
elif loc.parts[-1].endswith("zip"):
|
||||||
|
zip_file = zipfile.ZipFile(str(loc))
|
||||||
|
names = zip_file.namelist()
|
||||||
|
file_ = zip_file.open(names[0])
|
||||||
|
return (line.decode("utf8") for line in file_)
|
||||||
|
else:
|
||||||
|
return loc.open("r", encoding="utf8")
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_shape(lines):
|
||||||
|
"""Ensure that the first line of the data is the vectors shape.
|
||||||
|
If it's not, we read in the data and output the shape as the first result,
|
||||||
|
so that the reader doesn't have to deal with the problem.
|
||||||
|
"""
|
||||||
|
first_line = next(lines)
|
||||||
|
try:
|
||||||
|
shape = tuple(int(size) for size in first_line.split())
|
||||||
|
except ValueError:
|
||||||
|
shape = None
|
||||||
|
if shape is not None:
|
||||||
|
# All good, give the data
|
||||||
|
yield first_line
|
||||||
|
yield from lines
|
||||||
|
else:
|
||||||
|
# Figure out the shape, make it the first value, and then give the
|
||||||
|
# rest of the data.
|
||||||
|
width = len(first_line.split()) - 1
|
||||||
|
captured = [first_line] + list(lines)
|
||||||
|
length = len(captured)
|
||||||
|
yield f"{length} {width}"
|
||||||
|
yield from captured
|
||||||
|
|
Loading…
Reference in New Issue
Block a user