# spacy/training/initialize.py
from typing import Any, Dict, IO, List, Optional, Union
from pathlib import Path
from ast import literal_eval
import gzip
import math
import tarfile
import warnings
import zipfile

import numpy
import srsly
from preshed.counter import PreshCounter
from thinc.api import Config, fix_random_seed, set_gpu_allocator
from tqdm import tqdm
from wasabi import Printer, msg

from ._util import app, init_cli, Arg, Opt
from .. import util
from ..vectors import Vectors
from ..errors import Errors, Warnings
from ..language import Language
from ..util import ensure_path, get_lang_class, load_model, registry, dot_to_object, OOV_RANK

# The helpers below are used in init_pipeline() but aren't defined in this
# module. These import paths are an assumption based on where the training
# CLI keeps them, added so the module is importable.
from ._util import show_validation_error, get_sourced_components
from ..cli.train import create_before_to_disk_callback, load_from_paths, verify_config
from ..schemas import TrainingSchema

try:
    import ftfy
except ImportError:
    ftfy = None

# Default log-probability assigned to out-of-vocabulary words.
DEFAULT_OOV_PROB = -20

def must_initialize(init_path: Path, config_path: Path, overrides: Dict) -> bool:
    config = util.load_config(config_path, overrides=overrides)
    if not init_path.exists():
        return True
    elif not (init_path / "config.cfg").exists():
        return True
    else:
        init_cfg = util.load_config(init_path / "config.cfg", interpolate=True)
        if config.to_str() != init_cfg.to_str():
            return True
        else:
            return False

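# Hedged usage sketch (paths are illustrative): re-run initialization
# whenever the init directory is missing or its saved config no longer
# matches the current one.
#
#     if must_initialize(Path("training/init"), Path("config.cfg"), {}):
#         nlp = init_pipeline(util.load_config(Path("config.cfg")))
#         nlp.to_disk(Path("training/init"))
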
def init_pipeline(config: Config, use_gpu: int = -1):
    raw_config = config
    config = raw_config.interpolate()
    if config["training"]["seed"] is not None:
        fix_random_seed(config["training"]["seed"])
    allocator = config["training"]["gpu_allocator"]
    if use_gpu >= 0 and allocator:
        set_gpu_allocator(allocator)
    # Use original config here before it's resolved to functions
    sourced_components = get_sourced_components(config)
    # No config path is available in this function, so show validation errors
    # without a file reference
    with show_validation_error():
        nlp = util.load_model_from_config(raw_config)
    # Resolve all training-relevant sections using the filled nlp config
    T = registry.resolve(
        config["training"],
        schema=TrainingSchema,
        validate=True,
    )
    # TODO: It might not be 'corpora'
    corpora = registry.resolve(config["corpora"], validate=True)
    raw_text, tag_map, morph_rules, weights_data = load_from_paths(config)
    util.load_vocab_data_into_model(nlp, lookups=T["lookups"])
    if T["vectors"] is not None:
        # add_vectors() below expects (msg, nlp, vectors_loc, truncate, prune);
        # no truncation or pruning is applied at this point
        add_vectors(msg, nlp, T["vectors"], truncate_vectors=0, prune_vectors=-1)
    score_weights = T["score_weights"]
    optimizer = T["optimizer"]
    train_corpus = dot_to_object({"corpora": corpora}, T["train_corpus"])
    dev_corpus = dot_to_object({"corpora": corpora}, T["dev_corpus"])
    batcher = T["batcher"]
    train_logger = T["logger"]
    before_to_disk = create_before_to_disk_callback(T["before_to_disk"])
    # Components that shouldn't be updated during training
    frozen_components = T["frozen_components"]
    # Sourced components that require resume_training
    resume_components = [p for p in sourced_components if p not in frozen_components]
    msg.info(f"Pipeline: {nlp.pipe_names}")
    if resume_components:
        with nlp.select_pipes(enable=resume_components):
            msg.info(f"Resuming training for: {resume_components}")
            nlp.resume_training(sgd=optimizer)
    with nlp.select_pipes(disable=[*frozen_components, *resume_components]):
        nlp.begin_training(lambda: train_corpus(nlp), sgd=optimizer)
    # Verify the config after calling 'begin_training' to ensure labels
    # are properly initialized
    verify_config(nlp)
    if tag_map:
        # Replace tag map with provided mapping
        nlp.vocab.morphology.load_tag_map(tag_map)
    if morph_rules:
        # Load morph rules
        nlp.vocab.morphology.load_morph_exceptions(morph_rules)
    # Load pretrained tok2vec weights - cf. CLI command 'pretrain'
    if weights_data is not None:
        tok2vec_component = config["pretraining"]["component"]
        if tok2vec_component is None:
            msg.fail(
                "To use pretrained tok2vec weights, [pretraining.component] "
                "needs to specify the component that should load them.",
                exits=1,
            )
        layer = nlp.get_pipe(tok2vec_component).model
        tok2vec_layer = config["pretraining"]["layer"]
        if tok2vec_layer:
            layer = layer.get_ref(tok2vec_layer)
        layer.from_bytes(weights_data)
        msg.info(f"Loaded pretrained weights into component '{tok2vec_component}'")
    return nlp

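# Hedged usage sketch (the config path is illustrative):
#
#     config = util.load_config(Path("config.cfg"))
#     nlp = init_pipeline(config, use_gpu=0)  # use_gpu=-1 for CPU
#
# The returned pipeline has resume_training()/begin_training() already
# applied, so a training loop can go straight to nlp.update().
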
def init_vocab(
    lang: str,
    output_dir: Path,
    freqs_loc: Optional[Path] = None,
    clusters_loc: Optional[Path] = None,
    jsonl_loc: Optional[Path] = None,
    vectors_loc: Optional[Path] = None,
    prune_vectors: int = -1,
    truncate_vectors: int = 0,
    vectors_name: Optional[str] = None,
    model_name: Optional[str] = None,
    base_model: Optional[str] = None,
    silent: bool = True,
) -> Language:
    msg = Printer(no_print=silent, pretty=not silent)
    if jsonl_loc is not None:
        if freqs_loc is not None or clusters_loc is not None:
            settings = ["-j"]
            if freqs_loc:
                settings.append("-f")
            if clusters_loc:
                settings.append("-c")
            msg.warn(
                "Incompatible arguments",
                "The -f and -c arguments are deprecated, and not compatible "
                "with the -j argument, which should specify the same "
                "information. Either merge the frequencies and clusters data "
                "into the JSONL-formatted file (recommended), or use only the "
                "-f and -c files, without the other lexical attributes.",
            )
        jsonl_loc = ensure_path(jsonl_loc)
        lex_attrs = srsly.read_jsonl(jsonl_loc)
    else:
        clusters_loc = ensure_path(clusters_loc)
        freqs_loc = ensure_path(freqs_loc)
        if freqs_loc is not None and not freqs_loc.exists():
            msg.fail("Can't find word frequencies file", freqs_loc, exits=1)
        lex_attrs = read_attrs_from_deprecated(msg, freqs_loc, clusters_loc)
    with msg.loading("Creating blank pipeline..."):
        nlp = create_model(lang, lex_attrs, name=model_name, base_model=base_model)
    msg.good("Successfully created blank pipeline")
    if vectors_loc is not None:
        add_vectors(
            msg, nlp, vectors_loc, truncate_vectors, prune_vectors, vectors_name
        )
    vec_added = len(nlp.vocab.vectors)
    lex_added = len(nlp.vocab)
    msg.good(
        "Successfully compiled vocab", f"{lex_added} entries, {vec_added} vectors"
    )
    if not output_dir.exists():
        output_dir.mkdir()
    nlp.to_disk(output_dir)
    return nlp

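# Hedged usage sketch (paths are illustrative):
#
#     nlp = init_vocab(
#         "en",
#         Path("output/vocab"),
#         jsonl_loc=Path("lexeme_attrs.jsonl"),
#         vectors_loc=Path("vectors.txt.gz"),
#         silent=False,
#     )
#
# Each JSONL line is one lexeme attribute dict, matching what
# read_attrs_from_deprecated() produces, e.g.
# {"orth": "the", "prob": -3.5, "cluster": 42}.
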
def open_file(loc: Union[str, Path]) -> IO:
    """Handle .gz, .tar.gz or unzipped files."""
    loc = ensure_path(loc)
    if tarfile.is_tarfile(str(loc)):
        return tarfile.open(str(loc), "r:gz")
    elif loc.parts[-1].endswith("gz"):
        return (line.decode("utf8") for line in gzip.open(str(loc), "r"))
    elif loc.parts[-1].endswith("zip"):
        zip_file = zipfile.ZipFile(str(loc))
        names = zip_file.namelist()
        file_ = zip_file.open(names[0])
        return (line.decode("utf8") for line in file_)
    else:
        return loc.open("r", encoding="utf8")

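# Note that the tarfile branch returns the TarFile handle itself rather than
# an iterator of lines, so line-oriented callers such as read_vectors() rely
# on the .gz, .zip, or plain-text branches. Hedged usage, with an
# illustrative path:
#
#     for line in open_file("vectors.txt.gz"):
#         ...
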
def read_attrs_from_deprecated(
    msg: Printer, freqs_loc: Optional[Path], clusters_loc: Optional[Path]
) -> List[Dict[str, Any]]:
    if freqs_loc is not None:
        with msg.loading("Counting frequencies..."):
            probs, _ = read_freqs(freqs_loc)
        msg.good("Counted frequencies")
    else:
        probs, _ = ({}, DEFAULT_OOV_PROB)  # noqa: F841
    if clusters_loc:
        with msg.loading("Reading clusters..."):
            clusters = read_clusters(clusters_loc)
        msg.good("Read clusters")
    else:
        clusters = {}
    lex_attrs = []
    sorted_probs = sorted(probs.items(), key=lambda item: item[1], reverse=True)
    if len(sorted_probs):
        for i, (word, prob) in tqdm(enumerate(sorted_probs)):
            attrs = {"orth": word, "id": i, "prob": prob}
            # Decode as a little-endian string, so that we can do & 15 to get
            # the first 4 bits. See _parse_features.pyx
            if word in clusters:
                attrs["cluster"] = int(clusters[word][::-1], 2)
            else:
                attrs["cluster"] = 0
            lex_attrs.append(attrs)
    return lex_attrs

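# Worked example of the little-endian cluster decoding above:
# int("1011"[::-1], 2) == int("1101", 2) == 13, and 13 & 15 == 13, whose bits
# read lowest-first are 1, 0, 1, 1, i.e. the first four bits of the original
# bitstring.
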
def create_model(
    lang: str,
    lex_attrs: List[Dict[str, Any]],
    name: Optional[str] = None,
    base_model: Optional[Union[str, Path]] = None,
) -> Language:
    if base_model:
        nlp = load_model(base_model)
        # Keep the tokenizer but remove any existing pipeline components due
        # to potentially conflicting vectors
        for pipe in nlp.pipe_names:
            nlp.remove_pipe(pipe)
    else:
        lang_class = get_lang_class(lang)
        nlp = lang_class()
    for lexeme in nlp.vocab:
        lexeme.rank = OOV_RANK
    for attrs in lex_attrs:
        if "settings" in attrs:
            continue
        lexeme = nlp.vocab[attrs["orth"]]
        lexeme.set_attrs(**attrs)
    if len(nlp.vocab):
        oov_prob = min(lex.prob for lex in nlp.vocab) - 1
    else:
        oov_prob = DEFAULT_OOV_PROB
    nlp.vocab.cfg.update({"oov_prob": oov_prob})
    if name:
        nlp.meta["name"] = name
    return nlp

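# Sketch of the OOV probability fallback above: with lexeme probs
# [-3.2, -7.9], oov_prob becomes -8.9 (the minimum minus one); with an empty
# vocab it falls back to DEFAULT_OOV_PROB.
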
def add_vectors(
    msg: Printer,
    nlp: Language,
    vectors_loc: Optional[Path],
    truncate_vectors: int,
    prune_vectors: int,
    name: Optional[str] = None,
) -> None:
    vectors_loc = ensure_path(vectors_loc)
    if vectors_loc and vectors_loc.parts[-1].endswith(".npz"):
        nlp.vocab.vectors = Vectors(data=numpy.load(vectors_loc.open("rb")))
        for lex in nlp.vocab:
            if lex.rank and lex.rank != OOV_RANK:
                nlp.vocab.vectors.add(lex.orth, row=lex.rank)
    else:
        if vectors_loc:
            with msg.loading(f"Reading vectors from {vectors_loc}"):
                vectors_data, vector_keys = read_vectors(
                    msg, vectors_loc, truncate_vectors
                )
            msg.good(f"Loaded vectors from {vectors_loc}")
        else:
            vectors_data, vector_keys = (None, None)
        if vector_keys is not None:
            for word in vector_keys:
                if word not in nlp.vocab:
                    # Looking the word up adds the lexeme to the vocab
                    nlp.vocab[word]
        if vectors_data is not None:
            nlp.vocab.vectors = Vectors(data=vectors_data, keys=vector_keys)
    if name is None:
        # TODO: Is this correct? Does this matter?
        nlp.vocab.vectors.name = f"{nlp.meta['lang']}_{nlp.meta['name']}.vectors"
    else:
        nlp.vocab.vectors.name = name
    nlp.meta["vectors"]["name"] = nlp.vocab.vectors.name
    if prune_vectors >= 1:
        nlp.vocab.prune_vectors(prune_vectors)

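# add_vectors() accepts either a spaCy .npz array dump (rows aligned with
# existing lexeme ranks) or a text vectors table parsed by read_vectors()
# below; prune_vectors >= 1 then maps the vocab down to that many rows.
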
def read_vectors(msg: Printer, vectors_loc: Path, truncate_vectors: int):
    f = open_file(vectors_loc)
    f = ensure_shape(f)
    shape = tuple(int(size) for size in next(f).split())
    if truncate_vectors >= 1:
        shape = (truncate_vectors, shape[1])
    vectors_data = numpy.zeros(shape=shape, dtype="f")
    vectors_keys = []
    for i, line in enumerate(tqdm(f)):
        line = line.rstrip()
        pieces = line.rsplit(" ", vectors_data.shape[1])
        word = pieces.pop(0)
        if len(pieces) != vectors_data.shape[1]:
            msg.fail(Errors.E094.format(line_num=i, loc=vectors_loc), exits=1)
        vectors_data[i] = numpy.asarray(pieces, dtype="f")
        vectors_keys.append(word)
        if i == truncate_vectors - 1:
            break
    return vectors_data, vectors_keys

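# read_vectors() expects word2vec-style text: a "<count> <dim>" header line
# (ensure_shape() below synthesizes one if it's missing), then one
# space-separated "word v1 v2 ..." row per vector, e.g. (illustrative values):
#
#     2 3
#     the 0.418 0.249 -0.412
#     fox 0.013 0.301 0.203
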
def ensure_shape(lines):
    """Ensure that the first line of the data is the vectors shape.

    If it's not, we read in the data and output the shape as the first result,
    so that the reader doesn't have to deal with the problem.
    """
    first_line = next(lines)
    try:
        shape = tuple(int(size) for size in first_line.split())
    except ValueError:
        shape = None
    if shape is not None:
        # All good, give the data
        yield first_line
        yield from lines
    else:
        # Figure out the shape, make it the first value, and then give the
        # rest of the data.
        width = len(first_line.split()) - 1
        captured = [first_line] + list(lines)
        length = len(captured)
        yield f"{length} {width}"
        yield from captured

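# A hedged worked example for ensure_shape(), with illustrative data:
#
#     lines = iter(["the 0.1 0.2", "fox 0.3 0.4"])
#     fixed = ensure_shape(lines)
#     assert next(fixed) == "2 2"  # the inferred "<count> <width>" header
#
# Note the fallback path materializes the whole input in memory to count rows.
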
def read_freqs(
    freqs_loc: Path, max_length: int = 100, min_doc_freq: int = 5, min_freq: int = 50
):
    counts = PreshCounter()
    total = 0
    with freqs_loc.open() as f:
        for i, line in enumerate(f):
            freq, doc_freq, key = line.rstrip().split("\t", 2)
            freq = int(freq)
            counts.inc(i + 1, freq)
            total += freq
    counts.smooth()
    log_total = math.log(total)
    probs = {}
    with freqs_loc.open() as f:
        for line in tqdm(f):
            freq, doc_freq, key = line.rstrip().split("\t", 2)
            doc_freq = int(doc_freq)
            freq = int(freq)
            if doc_freq >= min_doc_freq and freq >= min_freq and len(key) < max_length:
                try:
                    word = literal_eval(key)
                except SyntaxError:
                    # Take odd strings literally.
                    word = literal_eval(f"'{key}'")
                smooth_count = counts.smoother(int(freq))
                probs[word] = math.log(smooth_count) - log_total
    oov_prob = math.log(counts.smoother(0)) - log_total
    return probs, oov_prob

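# The frequencies file is expected to contain tab-separated
# "<freq>\t<doc_freq>\t<key>" lines, where the key is typically a Python
# string literal (hence literal_eval), e.g. (illustrative counts):
#
#     1043999\t98765\t'the'
#
# Rows below min_doc_freq/min_freq or with keys of max_length or more are
# skipped; the returned probs map words to smoothed log-probabilities, and
# oov_prob is the smoothed log-probability of an unseen word.
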
def read_clusters(clusters_loc: Path) -> dict:
    clusters = {}
    if ftfy is None:
        warnings.warn(Warnings.W004)
    with clusters_loc.open() as f:
        for line in tqdm(f):
            try:
                cluster, word, freq = line.split()
                if ftfy is not None:
                    word = ftfy.fix_text(word)
            except ValueError:
                continue
            # If the clusterer has only seen the word a few times, its
            # cluster is unreliable.
            if int(freq) >= 3:
                clusters[word] = cluster
            else:
                clusters[word] = "0"
    # Expand clusters with re-casing
    for word, cluster in list(clusters.items()):
        if word.lower() not in clusters:
            clusters[word.lower()] = cluster
        if word.title() not in clusters:
            clusters[word.title()] = cluster
        if word.upper() not in clusters:
            clusters[word.upper()] = cluster
    return clusters

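# The clusters file uses the Brown-clusters format, whitespace-separated
# "<bitstring> <word> <freq>" lines, e.g. (illustrative):
#
#     1011 the 1043999
#
# Words the clusterer saw fewer than 3 times fall back to cluster "0", and
# common re-casings (lower/title/upper) inherit the original word's cluster.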