diff --git a/spacy/training/initialize.py b/spacy/training/initialize.py
new file mode 100644
index 000000000..07bbced8d
--- /dev/null
+++ b/spacy/training/initialize.py
@@ -0,0 +1,378 @@
+import gzip
+import math
+import tarfile
+import warnings
+import zipfile
+from ast import literal_eval
+from pathlib import Path
+from typing import IO, Any, Dict, List, Optional, Union
+
+import numpy
+import srsly
+from preshed.counter import PreshCounter
+from thinc.api import Config, fix_random_seed, set_gpu_allocator
+from tqdm import tqdm
+from wasabi import Printer, msg
+
+from .. import util
+from ..errors import Errors, Warnings
+from ..language import Language
+from ..util import (
+    OOV_RANK,
+    dot_to_object,
+    ensure_path,
+    get_lang_class,
+    load_model,
+    registry,
+)
+from ..vectors import Vectors
+from ._util import app, init_cli, Arg, Opt
+
+# NOTE: the following helpers are assumed to be importable from the WIP
+# CLI/training utilities while this refactor is in progress; their final
+# locations are still in flux: get_sourced_components, show_validation_error,
+# load_from_paths, create_before_to_disk_callback, verify_config,
+# TrainingSchema
+
+try:
+    import ftfy
+except ImportError:
+    ftfy = None
+
+# Default log-probability assigned to out-of-vocabulary words (same value as
+# the old init-model CLI used).
+DEFAULT_OOV_PROB = -20
+
+
+def must_initialize(init_path: Path, config_path: Path, overrides: Dict) -> bool:
+    config = util.load_config(config_path, overrides=overrides)
+    if not init_path.exists():
+        return True
+    elif not (init_path / "config.cfg").exists():
+        return True
+    else:
+        init_cfg = util.load_config(init_path / "config.cfg", interpolate=True)
+        if config.to_str() != init_cfg.to_str():
+            return True
+        else:
+            return False
+
+
+def init_pipeline(config: Config, use_gpu: int = -1) -> Language:
+    raw_config = config
+    config = raw_config.interpolate()
+    if config["training"]["seed"] is not None:
+        fix_random_seed(config["training"]["seed"])
+    allocator = config["training"]["gpu_allocator"]
+    if use_gpu >= 0 and allocator:
+        set_gpu_allocator(allocator)
+    # Use original config here before it's resolved to functions
+    sourced_components = get_sourced_components(config)
+    with show_validation_error():
+        nlp = util.load_model_from_config(raw_config)
+    # Resolve all training-relevant sections using the filled nlp config
+    T = registry.resolve(
+        config["training"],
+        schema=TrainingSchema,
+        validate=True,
+    )
+    # TODO: It might not be 'corpora'
+    corpora = registry.resolve(config["corpora"], validate=True)
+    raw_text, tag_map, morph_rules, weights_data = load_from_paths(config)
+    util.load_vocab_data_into_model(nlp, lookups=T["lookups"])
+    if T["vectors"] is not None:
+        # TODO: signature mismatch with add_vectors() defined below
+        add_vectors(nlp, T["vectors"])
+    score_weights = T["score_weights"]
+    optimizer = T["optimizer"]
+    train_corpus = dot_to_object({"corpora": corpora}, T["train_corpus"])
+    dev_corpus = dot_to_object({"corpora": corpora}, T["dev_corpus"])
+    batcher = T["batcher"]
+    train_logger = T["logger"]
+    before_to_disk = create_before_to_disk_callback(T["before_to_disk"])
+    # Components that shouldn't be updated during training
+    frozen_components = T["frozen_components"]
+    # Sourced components that require resume_training
+    resume_components = [p for p in sourced_components if p not in frozen_components]
+    msg.info(f"Pipeline: {nlp.pipe_names}")
+    if resume_components:
+        with nlp.select_pipes(enable=resume_components):
+            msg.info(f"Resuming training for: {resume_components}")
+            nlp.resume_training(sgd=optimizer)
+    with nlp.select_pipes(disable=[*frozen_components, *resume_components]):
+        nlp.begin_training(lambda: train_corpus(nlp), sgd=optimizer)
+    # Verify the config after calling 'begin_training' to ensure labels
+    # are properly initialized
+    verify_config(nlp)
+
+    if tag_map:
+        # Replace tag map with provided mapping
+        nlp.vocab.morphology.load_tag_map(tag_map)
+    if morph_rules:
+        # Load morph rules
+        nlp.vocab.morphology.load_morph_exceptions(morph_rules)
+
+    # Load pretrained tok2vec weights - cf. CLI command 'pretrain'
+    if weights_data is not None:
+        tok2vec_component = config["pretraining"]["component"]
+        if tok2vec_component is None:
+            msg.fail(
+                "To use pretrained tok2vec weights, [pretraining.component] "
+                "needs to specify the component that should load them.",
+                exits=1,
+            )
+        layer = nlp.get_pipe(tok2vec_component).model
+        tok2vec_layer = config["pretraining"]["layer"]
+        if tok2vec_layer:
+            layer = layer.get_ref(tok2vec_layer)
+        layer.from_bytes(weights_data)
+        msg.info(f"Loaded pretrained weights into component '{tok2vec_component}'")
+    return nlp
+
+
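+# A rough usage sketch for init_pipeline (illustrative only; assumes a
+# standard config.cfg on disk and that this function stays exposed from
+# spacy.training -- paths and overrides are made-up values):
+#
+#     from spacy import util
+#     config = util.load_config("config.cfg", overrides={"training.seed": 0})
+#     nlp = init_pipeline(config, use_gpu=-1)
+#     nlp.to_disk("training/initialized")
+
+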
+def init_vocab(
+    lang: str,
+    output_dir: Path,
+    freqs_loc: Optional[Path] = None,
+    clusters_loc: Optional[Path] = None,
+    jsonl_loc: Optional[Path] = None,
+    vectors_loc: Optional[Path] = None,
+    prune_vectors: int = -1,
+    truncate_vectors: int = 0,
+    vectors_name: Optional[str] = None,
+    model_name: Optional[str] = None,
+    base_model: Optional[str] = None,
+    silent: bool = True,
+) -> Language:
+    msg = Printer(no_print=silent, pretty=not silent)
+    if jsonl_loc is not None:
+        if freqs_loc is not None or clusters_loc is not None:
+            settings = ["-j"]
+            if freqs_loc:
+                settings.append("-f")
+            if clusters_loc:
+                settings.append("-c")
+            msg.warn(
+                "Incompatible arguments",
+                "The -f and -c arguments are deprecated, and not compatible "
+                "with the -j argument, which should specify the same "
+                "information. Either merge the frequencies and clusters data "
+                "into the JSONL-formatted file (recommended), or use only the "
+                "-f and -c files, without the other lexical attributes.",
+            )
+        jsonl_loc = ensure_path(jsonl_loc)
+        lex_attrs = srsly.read_jsonl(jsonl_loc)
+    else:
+        clusters_loc = ensure_path(clusters_loc)
+        freqs_loc = ensure_path(freqs_loc)
+        if freqs_loc is not None and not freqs_loc.exists():
+            msg.fail("Can't find word frequencies file", freqs_loc, exits=1)
+        lex_attrs = read_attrs_from_deprecated(msg, freqs_loc, clusters_loc)
+
+    with msg.loading("Creating blank pipeline..."):
+        nlp = create_model(lang, lex_attrs, name=model_name, base_model=base_model)
+
+    msg.good("Successfully created blank pipeline")
+    if vectors_loc is not None:
+        add_vectors(
+            msg, nlp, vectors_loc, truncate_vectors, prune_vectors, vectors_name
+        )
+    vec_added = len(nlp.vocab.vectors)
+    lex_added = len(nlp.vocab)
+    msg.good(
+        "Successfully compiled vocab", f"{lex_added} entries, {vec_added} vectors"
+    )
+    if not output_dir.exists():
+        output_dir.mkdir()
+    nlp.to_disk(output_dir)
+    return nlp
+
+
+def open_file(loc: Union[str, Path]) -> IO:
+    """Handle .gz, .tar.gz or unzipped files."""
+    loc = ensure_path(loc)
+    if tarfile.is_tarfile(str(loc)):
+        return tarfile.open(str(loc), "r:gz")
+    elif loc.parts[-1].endswith("gz"):
+        return (line.decode("utf8") for line in gzip.open(str(loc), "r"))
+    elif loc.parts[-1].endswith("zip"):
+        zip_file = zipfile.ZipFile(str(loc))
+        names = zip_file.namelist()
+        file_ = zip_file.open(names[0])
+        return (line.decode("utf8") for line in file_)
+    else:
+        return loc.open("r", encoding="utf8")
+
+
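+# Hedged usage note: depending on the branch taken, open_file returns either
+# a real file handle or a generator of decoded lines, so callers should only
+# iterate over the result rather than rely on file-object methods, e.g.:
+#
+#     for line in open_file("vectors.txt.gz"):
+#         ...
+
+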
+def read_attrs_from_deprecated(
+    msg: Printer, freqs_loc: Optional[Path], clusters_loc: Optional[Path]
+) -> List[Dict[str, Any]]:
+    if freqs_loc is not None:
+        with msg.loading("Counting frequencies..."):
+            probs, _ = read_freqs(freqs_loc)
+        msg.good("Counted frequencies")
+    else:
+        probs, _ = ({}, DEFAULT_OOV_PROB)  # noqa: F841
+    if clusters_loc:
+        with msg.loading("Reading clusters..."):
+            clusters = read_clusters(clusters_loc)
+        msg.good("Read clusters")
+    else:
+        clusters = {}
+    lex_attrs = []
+    sorted_probs = sorted(probs.items(), key=lambda item: item[1], reverse=True)
+    if len(sorted_probs):
+        for i, (word, prob) in tqdm(enumerate(sorted_probs)):
+            attrs = {"orth": word, "id": i, "prob": prob}
+            # Decode as a little-endian string, so that we can do & 15 to get
+            # the first 4 bits. See _parse_features.pyx
+            if word in clusters:
+                attrs["cluster"] = int(clusters[word][::-1], 2)
+            else:
+                attrs["cluster"] = 0
+            lex_attrs.append(attrs)
+    return lex_attrs
+
+
+def create_model(
+    lang: str,
+    lex_attrs: List[Dict[str, Any]],
+    name: Optional[str] = None,
+    base_model: Optional[Union[str, Path]] = None,
+) -> Language:
+    if base_model:
+        nlp = load_model(base_model)
+        # keep the tokenizer but remove any existing pipeline components due
+        # to potentially conflicting vectors
+        for pipe in nlp.pipe_names:
+            nlp.remove_pipe(pipe)
+    else:
+        lang_class = get_lang_class(lang)
+        nlp = lang_class()
+    for lexeme in nlp.vocab:
+        lexeme.rank = OOV_RANK
+    for attrs in lex_attrs:
+        if "settings" in attrs:
+            continue
+        lexeme = nlp.vocab[attrs["orth"]]
+        lexeme.set_attrs(**attrs)
+    if len(nlp.vocab):
+        oov_prob = min(lex.prob for lex in nlp.vocab) - 1
+    else:
+        oov_prob = DEFAULT_OOV_PROB
+    nlp.vocab.cfg.update({"oov_prob": oov_prob})
+    if name:
+        nlp.meta["name"] = name
+    return nlp
+
+
+def add_vectors(
+    msg: Printer,
+    nlp: Language,
+    vectors_loc: Optional[Path],
+    truncate_vectors: int,
+    prune_vectors: int,
+    name: Optional[str] = None,
+) -> None:
+    vectors_loc = ensure_path(vectors_loc)
+    if vectors_loc and vectors_loc.parts[-1].endswith(".npz"):
+        nlp.vocab.vectors = Vectors(data=numpy.load(vectors_loc.open("rb")))
+        for lex in nlp.vocab:
+            if lex.rank and lex.rank != OOV_RANK:
+                nlp.vocab.vectors.add(lex.orth, row=lex.rank)
+    else:
+        if vectors_loc:
+            with msg.loading(f"Reading vectors from {vectors_loc}"):
+                vectors_data, vector_keys = read_vectors(
+                    msg, vectors_loc, truncate_vectors
+                )
+            msg.good(f"Loaded vectors from {vectors_loc}")
+        else:
+            vectors_data, vector_keys = (None, None)
+        if vector_keys is not None:
+            for word in vector_keys:
+                if word not in nlp.vocab:
+                    nlp.vocab[word]  # adds the word to the vocab as a side effect
+        if vectors_data is not None:
+            nlp.vocab.vectors = Vectors(data=vectors_data, keys=vector_keys)
+    if name is None:
+        # TODO: Is this correct? Does this matter?
+        nlp.vocab.vectors.name = f"{nlp.meta['lang']}_{nlp.meta['name']}.vectors"
+    else:
+        nlp.vocab.vectors.name = name
+    nlp.meta["vectors"]["name"] = nlp.vocab.vectors.name
+    if prune_vectors >= 1:
+        nlp.vocab.prune_vectors(prune_vectors)
+
+
+def read_vectors(msg: Printer, vectors_loc: Path, truncate_vectors: int):
+    f = open_file(vectors_loc)
+    f = ensure_shape(f)
+    shape = tuple(int(size) for size in next(f).split())
+    if truncate_vectors >= 1:
+        shape = (truncate_vectors, shape[1])
+    vectors_data = numpy.zeros(shape=shape, dtype="f")
+    vectors_keys = []
+    for i, line in enumerate(tqdm(f)):
+        line = line.rstrip()
+        pieces = line.rsplit(" ", vectors_data.shape[1])
+        word = pieces.pop(0)
+        if len(pieces) != vectors_data.shape[1]:
+            msg.fail(Errors.E094.format(line_num=i, loc=vectors_loc), exits=1)
+        vectors_data[i] = numpy.asarray(pieces, dtype="f")
+        vectors_keys.append(word)
+        if i == truncate_vectors - 1:
+            break
+    return vectors_data, vectors_keys
+
+
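+# read_vectors expects word2vec-style text input: a "<rows> <dims>" header
+# line (added by ensure_shape below if missing), then one
+# "<word> <float> ... <float>" line per entry. Illustrative example with
+# made-up values:
+#
+#     2 3
+#     cat 0.1 0.2 0.3
+#     dog 0.4 0.5 0.6
+
+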
+ """ + first_line = next(lines) + try: + shape = tuple(int(size) for size in first_line.split()) + except ValueError: + shape = None + if shape is not None: + # All good, give the data + yield first_line + yield from lines + else: + # Figure out the shape, make it the first value, and then give the + # rest of the data. + width = len(first_line.split()) - 1 + captured = [first_line] + list(lines) + length = len(captured) + yield f"{length} {width}" + yield from captured + + +def read_freqs( + freqs_loc: Path, max_length: int = 100, min_doc_freq: int = 5, min_freq: int = 50 +): + counts = PreshCounter() + total = 0 + with freqs_loc.open() as f: + for i, line in enumerate(f): + freq, doc_freq, key = line.rstrip().split("\t", 2) + freq = int(freq) + counts.inc(i + 1, freq) + total += freq + counts.smooth() + log_total = math.log(total) + probs = {} + with freqs_loc.open() as f: + for line in tqdm(f): + freq, doc_freq, key = line.rstrip().split("\t", 2) + doc_freq = int(doc_freq) + freq = int(freq) + if doc_freq >= min_doc_freq and freq >= min_freq and len(key) < max_length: + try: + word = literal_eval(key) + except SyntaxError: + # Take odd strings literally. + word = literal_eval(f"'{key}'") + smooth_count = counts.smoother(int(freq)) + probs[word] = math.log(smooth_count) - log_total + oov_prob = math.log(counts.smoother(0)) - log_total + return probs, oov_prob + + +def read_clusters(clusters_loc: Path) -> dict: + clusters = {} + if ftfy is None: + warnings.warn(Warnings.W004) + with clusters_loc.open() as f: + for line in tqdm(f): + try: + cluster, word, freq = line.split() + if ftfy is not None: + word = ftfy.fix_text(word) + except ValueError: + continue + # If the clusterer has only seen the word a few times, its + # cluster is unreliable. + if int(freq) >= 3: + clusters[word] = cluster + else: + clusters[word] = "0" + # Expand clusters with re-casing + for word, cluster in list(clusters.items()): + if word.lower() not in clusters: + clusters[word.lower()] = cluster + if word.title() not in clusters: + clusters[word.title()] = cluster + if word.upper() not in clusters: + clusters[word.upper()] = cluster + return clusters