Upd initialize args

This commit is contained in:
Matthew Honnibal 2020-09-29 12:08:37 +02:00
commit 9c8b2524fe
5 changed files with 251 additions and 11 deletions

View File

@ -5,11 +5,35 @@ from wasabi import msg
import typer import typer
from .. import util from .. import util
from ..training.initialize import init_nlp from ..training.initialize import init_nlp, convert_vectors
from ._util import init_cli, Arg, Opt, parse_config_overrides, show_validation_error from ._util import init_cli, Arg, Opt, parse_config_overrides, show_validation_error
from ._util import import_code, setup_gpu from ._util import import_code, setup_gpu
@init_cli.command("vectors")
def init_vectors_cli(
# fmt: off
lang: str = Arg(..., help="The language of the nlp object to create"),
vectors_loc: Path = Arg(..., help="Vectors file in Word2Vec format", exists=True),
output_dir: Path = Arg(..., help="Pipeline output directory"),
prune: int = Opt(-1, "--prune", "-p", help="Optional number of vectors to prune to"),
truncate: int = Opt(0, "--truncate", "-t", help="Optional number of vectors to truncate to when reading in vectors file"),
name: Optional[str] = Opt(None, "--name", "-n", help="Optional name for the word vectors, e.g. en_core_web_lg.vectors"),
# fmt: on
):
msg.info(f"Creating blank nlp object for language '{lang}'")
nlp = util.get_lang_class(lang)()
convert_vectors(
nlp, vectors_loc, truncate=truncate, prune=prune, name=name, silent=False
)
nlp.to_disk(output_dir)
msg.good(
"Saved nlp object with vectors to output directory. You can now use the "
"path to it in your config as the 'vectors' setting in [initialize.vocab].",
output_dir,
)
@init_cli.command( @init_cli.command(
"nlp", "nlp",
context_settings={"allow_extra_args": True, "ignore_unknown_options": True}, context_settings={"allow_extra_args": True, "ignore_unknown_options": True},

View File

@ -27,7 +27,7 @@ from .lang.punctuation import TOKENIZER_INFIXES
from .tokens import Doc from .tokens import Doc
from .tokenizer import Tokenizer from .tokenizer import Tokenizer
from .errors import Errors, Warnings from .errors import Errors, Warnings
from .schemas import ConfigSchema, ConfigSchemaNlp from .schemas import ConfigSchema, ConfigSchemaNlp, validate_init_settings
from .git_info import GIT_VERSION from .git_info import GIT_VERSION
from . import util from . import util
from . import about from . import about
@ -1161,8 +1161,10 @@ class Language:
def initialize( def initialize(
self, self,
get_examples: Optional[Callable[[], Iterable[Example]]] = None, get_examples: Optional[Callable[[], Iterable[Example]]] = None,
sgd: Optional[Optimizer]=None *,
) -> None: settings: Dict[str, Dict[str, Any]] = SimpleFrozenDict(),
sgd: Optional[Optimizer] = None,
) -> Optimizer:
"""Initialize the pipe for training, using data examples if available. """Initialize the pipe for training, using data examples if available.
get_examples (Callable[[], Iterable[Example]]): Optional function that get_examples (Callable[[], Iterable[Example]]): Optional function that
@ -1200,10 +1202,28 @@ class Language:
if self.vocab.vectors.data.shape[1] >= 1: if self.vocab.vectors.data.shape[1] >= 1:
ops = get_current_ops() ops = get_current_ops()
self.vocab.vectors.data = ops.asarray(self.vocab.vectors.data) self.vocab.vectors.data = ops.asarray(self.vocab.vectors.data)
self._optimizer = sgd
if hasattr(self.tokenizer, "initialize"):
tok_settings = settings.get("tokenizer", {})
tok_settings = validate_init_settings(
self.tokenizer.initialize,
tok_settings,
section="tokenizer",
name="tokenizer",
)
self.tokenizer.initialize(get_examples, nlp=self, **tok_settings)
proc_settings = settings.get("components", {})
for name, proc in self.pipeline: for name, proc in self.pipeline:
if hasattr(proc, "initialize"): if hasattr(proc, "initialize"):
p_settings = proc_settings.get(name, {})
p_settings = validate_init_settings(
proc.initialize, p_settings, section="components", name=name
)
proc.initialize( proc.initialize(
get_examples, pipeline=self.pipeline get_examples, pipeline=self.pipeline
get_examples,
pipeline=self.pipeline,
**p_settings,
) )
self._link_components() self._link_components()
if sgd is not None: if sgd is not None:

View File

@ -1,4 +1,4 @@
# cython: infer_types=True, cdivision=True, boundscheck=False # cython: infer_types=True, cdivision=True, boundscheck=False, binding=True
from __future__ import print_function from __future__ import print_function
from cymem.cymem cimport Pool from cymem.cymem cimport Pool
cimport numpy as np cimport numpy as np

View File

@ -1,11 +1,13 @@
from typing import Dict, List, Union, Optional, Any, Callable, Type, Tuple from typing import Dict, List, Union, Optional, Any, Callable, Type, Tuple
from typing import Iterable, TypeVar, TYPE_CHECKING from typing import Iterable, TypeVar, TYPE_CHECKING
from enum import Enum from enum import Enum
from pydantic import BaseModel, Field, ValidationError, validator from pydantic import BaseModel, Field, ValidationError, validator, create_model
from pydantic import StrictStr, StrictInt, StrictFloat, StrictBool from pydantic import StrictStr, StrictInt, StrictFloat, StrictBool
from pydantic.main import ModelMetaclass
from thinc.api import Optimizer, ConfigValidationError
from thinc.config import Promise from thinc.config import Promise
from collections import defaultdict from collections import defaultdict
from thinc.api import Optimizer import inspect
from .attrs import NAMES from .attrs import NAMES
from .lookups import Lookups from .lookups import Lookups
@ -43,6 +45,93 @@ def validate(schema: Type[BaseModel], obj: Dict[str, Any]) -> List[str]:
return [f"[{loc}] {', '.join(msg)}" for loc, msg in data.items()] return [f"[{loc}] {', '.join(msg)}" for loc, msg in data.items()]
# Initialization
class ArgSchemaConfig:
extra = "forbid"
arbitrary_types_allowed = True
class ArgSchemaConfigExtra:
extra = "forbid"
arbitrary_types_allowed = True
def get_arg_model(
func: Callable,
*,
exclude: Iterable[str] = tuple(),
name: str = "ArgModel",
strict: bool = True,
) -> ModelMetaclass:
"""Generate a pydantic model for function arguments.
func (Callable): The function to generate the schema for.
exclude (Iterable[str]): Parameter names to ignore.
name (str): Name of created model class.
strict (bool): Don't allow extra arguments if no variable keyword arguments
are allowed on the function.
RETURNS (ModelMetaclass): A pydantic model.
"""
sig_args = {}
try:
sig = inspect.signature(func)
except ValueError:
# Typically happens if the method is part of a Cython module without
# binding=True. Here we just use an empty model that allows everything.
return create_model(name, __config__=ArgSchemaConfigExtra)
has_variable = False
for param in sig.parameters.values():
if param.name in exclude:
continue
if param.kind == param.VAR_KEYWORD:
# The function allows variable keyword arguments so we shouldn't
# include **kwargs etc. in the schema and switch to non-strict
# mode and pass through all other values
has_variable = True
continue
# If no annotation is specified assume it's anything
annotation = param.annotation if param.annotation != param.empty else Any
# If no default value is specified assume that it's required
default = param.default if param.default != param.empty else ...
sig_args[param.name] = (annotation, default)
is_strict = strict and not has_variable
sig_args["__config__"] = ArgSchemaConfig if is_strict else ArgSchemaConfigExtra
return create_model(name, **sig_args)
def validate_init_settings(
func: Callable,
settings: Dict[str, Any],
*,
section: Optional[str] = None,
name: str = "",
exclude: Iterable[str] = ("get_examples", "nlp", "pipeline", "sgd"),
) -> Dict[str, Any]:
"""Validate initialization settings against the expected arguments in
the method signature. Will parse values if possible (e.g. int to string)
and return the updated settings dict. Will raise a ConfigValidationError
if types don't match or required values are missing.
func (Callable): The initialize method of a given component etc.
settings (Dict[str, Any]): The settings from the repsective [initialize] block.
section (str): Initialize section, for error message.
name (str): Name of the block in the section.
exclude (Iterable[str]): Parameter names to exclude from schema.
RETURNS (Dict[str, Any]): The validated settings.
"""
schema = get_arg_model(func, exclude=exclude, name="InitArgModel")
try:
return schema(**settings).dict()
except ValidationError as e:
block = "initialize" if not section else f"initialize.{section}"
title = f"Error validating initialization settings in [{block}]"
raise ConfigValidationError(
title=title, errors=e.errors(), config=settings, parent=name,
) from None
# Matcher token patterns # Matcher token patterns

View File

@ -1,13 +1,19 @@
from typing import Union, Dict, Optional, Any, List from typing import Union, Dict, Optional, Any, List, IO
from thinc.api import Config, fix_random_seed, set_gpu_allocator from thinc.api import Config, fix_random_seed, set_gpu_allocator
from thinc.api import ConfigValidationError from thinc.api import ConfigValidationError
from pathlib import Path from pathlib import Path
from wasabi import Printer from wasabi import Printer
import srsly import srsly
import numpy
import tarfile
import gzip
import zipfile
import tqdm
from .loop import create_before_to_disk_callback from .loop import create_before_to_disk_callback
from ..language import Language from ..language import Language
from ..lookups import Lookups from ..lookups import Lookups
from ..vectors import Vectors
from ..errors import Errors from ..errors import Errors
from ..schemas import ConfigSchemaTraining, ConfigSchemaInit, ConfigSchemaPretrain from ..schemas import ConfigSchemaTraining, ConfigSchemaInit, ConfigSchemaPretrain
from ..util import registry, load_model_from_config, resolve_dot_names from ..util import registry, load_model_from_config, resolve_dot_names
@ -49,8 +55,8 @@ def init_nlp(config: Config, *, use_gpu: int = -1, silent: bool = True) -> Langu
msg.info(f"Resuming training for: {resume_components}") msg.info(f"Resuming training for: {resume_components}")
nlp.resume_training(sgd=optimizer) nlp.resume_training(sgd=optimizer)
with nlp.select_pipes(disable=[*frozen_components, *resume_components]): with nlp.select_pipes(disable=[*frozen_components, *resume_components]):
nlp.initialize(lambda: train_corpus(nlp), sgd=optimizer) nlp.initialize(lambda: train_corpus(nlp), sgd=optimizer, settings=I)
msg.good(f"Initialized pipeline components") msg.good("Initialized pipeline components")
# Verify the config after calling 'initialize' to ensure labels # Verify the config after calling 'initialize' to ensure labels
# are properly initialized # are properly initialized
verify_config(nlp) verify_config(nlp)
@ -103,7 +109,7 @@ def init_vocab(
def load_vectors_into_model( def load_vectors_into_model(
nlp: "Language", name: Union[str, Path], *, add_strings: bool = True nlp: Language, name: Union[str, Path], *, add_strings: bool = True
) -> None: ) -> None:
"""Load word vectors from an installed model or path into a model instance.""" """Load word vectors from an installed model or path into a model instance."""
try: try:
@ -202,3 +208,104 @@ def get_sourced_components(config: Union[Dict[str, Any], Config]) -> List[str]:
for name, cfg in config.get("components", {}).items() for name, cfg in config.get("components", {}).items()
if "factory" not in cfg and "source" in cfg if "factory" not in cfg and "source" in cfg
] ]
def convert_vectors(
nlp: Language,
vectors_loc: Optional[Path],
*,
truncate: int,
prune: int,
name: Optional[str] = None,
silent: bool = True,
) -> None:
msg = Printer(no_print=silent)
vectors_loc = ensure_path(vectors_loc)
if vectors_loc and vectors_loc.parts[-1].endswith(".npz"):
nlp.vocab.vectors = Vectors(data=numpy.load(vectors_loc.open("rb")))
for lex in nlp.vocab:
if lex.rank and lex.rank != OOV_RANK:
nlp.vocab.vectors.add(lex.orth, row=lex.rank)
else:
if vectors_loc:
with msg.loading(f"Reading vectors from {vectors_loc}"):
vectors_data, vector_keys = read_vectors(vectors_loc, truncate)
msg.good(f"Loaded vectors from {vectors_loc}")
else:
vectors_data, vector_keys = (None, None)
if vector_keys is not None:
for word in vector_keys:
if word not in nlp.vocab:
nlp.vocab[word]
if vectors_data is not None:
nlp.vocab.vectors = Vectors(data=vectors_data, keys=vector_keys)
if name is None:
# TODO: Is this correct? Does this matter?
nlp.vocab.vectors.name = f"{nlp.meta['lang']}_{nlp.meta['name']}.vectors"
else:
nlp.vocab.vectors.name = name
nlp.meta["vectors"]["name"] = nlp.vocab.vectors.name
if prune >= 1:
nlp.vocab.prune_vectors(prune)
msg.good(f"Successfully converted {len(nlp.vocab.vectors)} vectors")
def read_vectors(vectors_loc: Path, truncate_vectors: int):
f = open_file(vectors_loc)
f = ensure_shape(f)
shape = tuple(int(size) for size in next(f).split())
if truncate_vectors >= 1:
shape = (truncate_vectors, shape[1])
vectors_data = numpy.zeros(shape=shape, dtype="f")
vectors_keys = []
for i, line in enumerate(tqdm.tqdm(f)):
line = line.rstrip()
pieces = line.rsplit(" ", vectors_data.shape[1])
word = pieces.pop(0)
if len(pieces) != vectors_data.shape[1]:
raise ValueError(Errors.E094.format(line_num=i, loc=vectors_loc))
vectors_data[i] = numpy.asarray(pieces, dtype="f")
vectors_keys.append(word)
if i == truncate_vectors - 1:
break
return vectors_data, vectors_keys
def open_file(loc: Union[str, Path]) -> IO:
"""Handle .gz, .tar.gz or unzipped files"""
loc = ensure_path(loc)
if tarfile.is_tarfile(str(loc)):
return tarfile.open(str(loc), "r:gz")
elif loc.parts[-1].endswith("gz"):
return (line.decode("utf8") for line in gzip.open(str(loc), "r"))
elif loc.parts[-1].endswith("zip"):
zip_file = zipfile.ZipFile(str(loc))
names = zip_file.namelist()
file_ = zip_file.open(names[0])
return (line.decode("utf8") for line in file_)
else:
return loc.open("r", encoding="utf8")
def ensure_shape(lines):
"""Ensure that the first line of the data is the vectors shape.
If it's not, we read in the data and output the shape as the first result,
so that the reader doesn't have to deal with the problem.
"""
first_line = next(lines)
try:
shape = tuple(int(size) for size in first_line.split())
except ValueError:
shape = None
if shape is not None:
# All good, give the data
yield first_line
yield from lines
else:
# Figure out the shape, make it the first value, and then give the
# rest of the data.
width = len(first_line.split()) - 1
captured = [first_line] + list(lines)
length = len(captured)
yield f"{length} {width}"
yield from captured