from typing import Optional, List, Dict, Any, Union, IO
import math
from tqdm import tqdm
import numpy
from ast import literal_eval
from pathlib import Path
from preshed.counter import PreshCounter
import tarfile
import gzip
import zipfile
import srsly
import warnings
from wasabi import Printer

from ._util import app, Arg, Opt
from ..vectors import Vectors
from ..errors import Errors, Warnings
from ..language import Language
from ..util import ensure_path, get_lang_class, load_model, OOV_RANK
from ..lookups import Lookups


try:
    import ftfy
except ImportError:
    ftfy = None


# Fallback out-of-vocabulary log probability, used when no frequency data
# is available.
DEFAULT_OOV_PROB = -20


@app.command("init-model")
def init_model_cli(
    # fmt: off
    lang: str = Arg(..., help="Model language"),
    output_dir: Path = Arg(..., help="Model output directory"),
    freqs_loc: Optional[Path] = Arg(None, help="Location of words frequencies file", exists=True),
    clusters_loc: Optional[Path] = Opt(None, "--clusters-loc", "-c", help="Optional location of brown clusters data", exists=True),
    jsonl_loc: Optional[Path] = Opt(None, "--jsonl-loc", "-j", help="Location of JSONL-formatted attributes file", exists=True),
    vectors_loc: Optional[Path] = Opt(None, "--vectors-loc", "-v", help="Optional vectors file in Word2Vec format", exists=True),
    prune_vectors: int = Opt(-1, "--prune-vectors", "-V", help="Optional number of vectors to prune to"),
    truncate_vectors: int = Opt(0, "--truncate-vectors", "-t", help="Optional number of vectors to truncate to when reading in vectors file"),
    vectors_name: Optional[str] = Opt(None, "--vectors-name", "-vn", help="Optional name for the word vectors, e.g. en_core_web_lg.vectors"),
    model_name: Optional[str] = Opt(None, "--model-name", "-mn", help="Optional name for the model meta"),
    base_model: Optional[str] = Opt(None, "--base-model", "-b", help="Base model (for languages with custom tokenizers)")
    # fmt: on
):
    """
    Create a new model from raw data. If vectors are provided in Word2Vec format,
    they can be either a .txt or zipped as a .zip or .tar.gz.
    """
    # The CLI entry point always prints progress (silent=False); the return
    # value of init_model is discarded here.
    init_model(
        lang,
        output_dir,
        freqs_loc=freqs_loc,
        clusters_loc=clusters_loc,
        jsonl_loc=jsonl_loc,
        vectors_loc=vectors_loc,
        prune_vectors=prune_vectors,
        truncate_vectors=truncate_vectors,
        vectors_name=vectors_name,
        model_name=model_name,
        base_model=base_model,
        silent=False,
    )


def init_model(
    lang: str,
    output_dir: Path,
    freqs_loc: Optional[Path] = None,
    clusters_loc: Optional[Path] = None,
    jsonl_loc: Optional[Path] = None,
    vectors_loc: Optional[Path] = None,
    prune_vectors: int = -1,
    truncate_vectors: int = 0,
    vectors_name: Optional[str] = None,
    model_name: Optional[str] = None,
    base_model: Optional[str] = None,
    silent: bool = True,
) -> Language:
    """Create a model from lexical attributes and optional vectors and save it.

    lang: Language code used to look up the language class (ignored by
        create_model when base_model is given).
    output_dir: Directory the model is written to. Created, including any
        missing parents, if it does not exist.
    freqs_loc: Deprecated tab-separated frequencies file.
    clusters_loc: Deprecated clusters file.
    jsonl_loc: JSONL file of lexical attributes. Takes precedence over
        freqs_loc and clusters_loc; a warning is printed if both are given.
    vectors_loc: Vectors file; if None, no vectors are added.
    prune_vectors: Prune the vectors table to this many rows if >= 1.
    truncate_vectors: Read at most this many vectors if >= 1.
    vectors_name: Name for the vectors table.
    model_name: Value stored under the "name" key of the model meta.
    base_model: Model to load and reuse instead of a blank language class.
    silent: If True, suppress all printed output.
    RETURNS: The created Language object (also saved to output_dir).
    """
    msg = Printer(no_print=silent, pretty=not silent)
    if jsonl_loc is not None:
        if freqs_loc is not None or clusters_loc is not None:
            msg.warn(
                "Incompatible arguments",
                "The -f and -c arguments are deprecated, and not compatible "
                "with the -j argument, which should specify the same "
                "information. Either merge the frequencies and clusters data "
                "into the JSONL-formatted file (recommended), or use only the "
                "-f and -c files, without the other lexical attributes.",
            )
        jsonl_loc = ensure_path(jsonl_loc)
        lex_attrs = srsly.read_jsonl(jsonl_loc)
    else:
        clusters_loc = ensure_path(clusters_loc)
        freqs_loc = ensure_path(freqs_loc)
        if freqs_loc is not None and not freqs_loc.exists():
            msg.fail("Can't find words frequencies file", freqs_loc, exits=1)
        lex_attrs = read_attrs_from_deprecated(msg, freqs_loc, clusters_loc)

    with msg.loading("Creating model..."):
        nlp = create_model(lang, lex_attrs, name=model_name, base_model=base_model)
    msg.good("Successfully created model")
    if vectors_loc is not None:
        add_vectors(
            msg, nlp, vectors_loc, truncate_vectors, prune_vectors, vectors_name
        )
    vec_added = len(nlp.vocab.vectors)
    lex_added = len(nlp.vocab)
    msg.good(
        "Successfully compiled vocab", f"{lex_added} entries, {vec_added} vectors",
    )
    # Accept a plain string path from programmatic callers and create any
    # missing parent directories instead of failing on a nested output path.
    output_dir = ensure_path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    nlp.to_disk(output_dir)
    return nlp


def _iter_decoded_lines(handle, *owners):
    """Yield utf8-decoded lines from a binary handle.

    The handle and any owning archive objects passed in `owners` are closed
    once the generator is exhausted or closed.
    """
    try:
        for line in handle:
            yield line.decode("utf8")
    finally:
        handle.close()
        for owner in owners:
            owner.close()


def open_file(loc: Union[str, Path]) -> IO:
    """Handle .gz, .tar.gz or unzipped files.

    loc: Path to a plain-text, gzipped, zipped or tar-archived text file.
    RETURNS: An iterable of decoded text lines that supports close(). For
        archives (.zip, tar), the lines of the first file in the archive.
    RAISES: ValueError if a tar archive contains no regular file.
    """
    loc = ensure_path(loc)
    if tarfile.is_tarfile(str(loc)):
        # A TarFile object iterates over archive members, not text lines, so
        # open the first regular member and stream its lines instead.
        # "r:*" lets tarfile detect the compression transparently.
        archive = tarfile.open(str(loc), "r:*")
        member = next((info for info in archive if info.isfile()), None)
        if member is None:
            archive.close()
            raise ValueError(f"No regular file found in tar archive: {loc}")
        return _iter_decoded_lines(archive.extractfile(member), archive)
    elif loc.parts[-1].endswith("gz"):
        return _iter_decoded_lines(gzip.open(str(loc), "r"))
    elif loc.parts[-1].endswith("zip"):
        zip_file = zipfile.ZipFile(str(loc))
        names = zip_file.namelist()
        file_ = zip_file.open(names[0])
        return _iter_decoded_lines(file_, zip_file)
    else:
        return loc.open("r", encoding="utf8")


def read_attrs_from_deprecated(
    msg: Printer, freqs_loc: Optional[Path], clusters_loc: Optional[Path]
) -> List[Dict[str, Any]]:
    """Build lexical attribute dicts from the deprecated freqs/clusters files.

    msg: Printer used for progress output.
    freqs_loc: Frequencies file, or None to use no probabilities.
    clusters_loc: Clusters file, or None to use no clusters.
    RETURNS: One dict per word with "orth", "id", "prob" and "cluster" keys,
        ordered by descending probability. Empty if no frequencies are given.
    """
    if freqs_loc is not None:
        with msg.loading("Counting frequencies..."):
            probs, _ = read_freqs(freqs_loc)
        msg.good("Counted frequencies")
    else:
        probs, _ = ({}, DEFAULT_OOV_PROB)  # noqa: F841
    if clusters_loc:
        with msg.loading("Reading clusters..."):
            clusters = read_clusters(clusters_loc)
        msg.good("Read clusters")
    else:
        clusters = {}
    lex_attrs = []
    sorted_probs = sorted(probs.items(), key=lambda item: item[1], reverse=True)
    if sorted_probs:
        for i, (word, prob) in tqdm(enumerate(sorted_probs)):
            attrs = {"orth": word, "id": i, "prob": prob}
            # Decode as a little-endian string, so that we can do & 15 to get
            # the first 4 bits. See _parse_features.pyx
            cluster = clusters.get(word)
            attrs["cluster"] = int(cluster[::-1], 2) if cluster is not None else 0
            lex_attrs.append(attrs)
    return lex_attrs


def create_model(
    lang: str,
    lex_attrs: List[Dict[str, Any]],
    name: Optional[str] = None,
    base_model: Optional[Union[str, Path]] = None,
) -> Language:
    """Create a Language object and populate its vocab with lexical attributes.

    lang: Language code used to get the language class if no base_model is set.
    lex_attrs: Iterable of attribute dicts; each must have an "orth" key.
        Entries containing a "settings" key are skipped.
    name: Optional value for the "name" key of the model meta.
    base_model: Optional model to load; its pipeline components are removed.
    RETURNS: The populated Language object.
    """
    if base_model:
        nlp = load_model(base_model)
        # keep the tokenizer but remove any existing pipeline components due to
        # potentially conflicting vectors
        for pipe in nlp.pipe_names:
            nlp.remove_pipe(pipe)
    else:
        lang_class = get_lang_class(lang)
        nlp = lang_class()
    # Reset the rank of every pre-existing lexeme before applying the
    # attributes from lex_attrs.
    for lexeme in nlp.vocab:
        lexeme.rank = OOV_RANK
    for attrs in lex_attrs:
        if "settings" in attrs:
            continue
        lexeme = nlp.vocab[attrs["orth"]]
        lexeme.set_attrs(**attrs)
    # The OOV probability is set just below the least probable known lexeme.
    if len(nlp.vocab):
        oov_prob = min(lex.prob for lex in nlp.vocab) - 1
    else:
        oov_prob = DEFAULT_OOV_PROB
    nlp.vocab.cfg.update({"oov_prob": oov_prob})
    if name:
        nlp.meta["name"] = name
    return nlp


def add_vectors(
    msg: Printer,
    nlp: Language,
    vectors_loc: Optional[Path],
    truncate_vectors: int,
    prune_vectors: int,
    name: Optional[str] = None,
) -> None:
    """Load vectors from disk and attach them to the vocab of nlp (in place).

    msg: Printer used for progress output.
    nlp: The Language object whose vocab receives the vectors.
    vectors_loc: Vectors file. A file name ending in ".npz" is loaded with
        numpy.load; anything else is parsed by read_vectors.
    truncate_vectors: Read at most this many vectors if >= 1.
    prune_vectors: Prune the vectors table to this many rows if >= 1.
    name: Name for the vectors table. Defaults to "{lang}_model.vectors".
    """
    vectors_loc = ensure_path(vectors_loc)
    if vectors_loc and vectors_loc.parts[-1].endswith(".npz"):
        nlp.vocab.vectors = Vectors(data=numpy.load(vectors_loc.open("rb")))
        # Rows are assigned by lexeme rank; unranked and OOV lexemes get none.
        for lex in nlp.vocab:
            if lex.rank and lex.rank != OOV_RANK:
                nlp.vocab.vectors.add(lex.orth, row=lex.rank)
    else:
        if vectors_loc:
            with msg.loading(f"Reading vectors from {vectors_loc}"):
                vectors_data, vector_keys = read_vectors(
                    msg, vectors_loc, truncate_vectors
                )
            msg.good(f"Loaded vectors from {vectors_loc}")
        else:
            vectors_data, vector_keys = (None, None)
        if vector_keys is not None:
            for word in vector_keys:
                if word not in nlp.vocab:
                    # Looking the word up adds an entry for it to the vocab.
                    nlp.vocab[word]
        if vectors_data is not None:
            nlp.vocab.vectors = Vectors(data=vectors_data, keys=vector_keys)
    if name is None:
        nlp.vocab.vectors.name = f"{nlp.meta['lang']}_model.vectors"
    else:
        nlp.vocab.vectors.name = name
    nlp.meta["vectors"]["name"] = nlp.vocab.vectors.name
    if prune_vectors >= 1:
        nlp.vocab.prune_vectors(prune_vectors)


def read_vectors(msg: Printer, vectors_loc: Path, truncate_vectors: int):
    """Read a text-format vectors file into an array and a list of keys.

    The first line is parsed as whitespace-separated integers giving the
    table shape; every following line is a word followed by its
    space-separated vector values.

    msg: Printer used to report malformed lines (exits with code 1).
    vectors_loc: Location of the vectors file (see open_file for formats).
    truncate_vectors: If >= 1, read at most this many rows.
    RETURNS: Tuple of (float array of vectors, list of word keys).
    """
    f = open_file(vectors_loc)
    try:
        shape = tuple(int(size) for size in next(f).split())
        if truncate_vectors >= 1:
            shape = (truncate_vectors, shape[1])
        vectors_data = numpy.zeros(shape=shape, dtype="f")
        vectors_keys = []
        for i, line in enumerate(tqdm(f)):
            line = line.rstrip()
            # Split from the right so words that contain spaces stay intact.
            pieces = line.rsplit(" ", vectors_data.shape[1])
            word = pieces.pop(0)
            if len(pieces) != vectors_data.shape[1]:
                msg.fail(Errors.E094.format(line_num=i, loc=vectors_loc), exits=1)
            vectors_data[i] = numpy.asarray(pieces, dtype="f")
            vectors_keys.append(word)
            if i == truncate_vectors - 1:
                break
    finally:
        # Release the underlying file/archive handles even when stopping
        # early (truncation) or exiting on a malformed line.
        f.close()
    return vectors_data, vectors_keys


def read_freqs(
    freqs_loc: Path, max_length: int = 100, min_doc_freq: int = 5, min_freq: int = 50
):
    """Read a frequencies file and compute smoothed log probabilities.

    Each line is expected to hold three tab-separated fields: frequency,
    document frequency and the key (a Python string literal of the word).

    freqs_loc: Location of the frequencies file (read as utf8).
    max_length: Keys of this length or longer are skipped.
    min_doc_freq: Minimum document frequency for a word to be kept.
    min_freq: Minimum frequency for a word to be kept.
    RETURNS: Tuple of (dict mapping word to log probability, OOV log
        probability).
    """
    counts = PreshCounter()
    total = 0
    # First pass: collect the counts needed for smoothing and the total.
    with freqs_loc.open(encoding="utf8") as f:
        for i, line in enumerate(f):
            freq, _, _ = line.rstrip().split("\t", 2)
            freq = int(freq)
            counts.inc(i + 1, freq)
            total += freq
    counts.smooth()
    log_total = math.log(total)
    probs = {}
    # Second pass: compute the probability of every word passing the filters.
    with freqs_loc.open(encoding="utf8") as f:
        for line in tqdm(f):
            freq, doc_freq, key = line.rstrip().split("\t", 2)
            doc_freq = int(doc_freq)
            freq = int(freq)
            if doc_freq >= min_doc_freq and freq >= min_freq and len(key) < max_length:
                try:
                    word = literal_eval(key)
                except SyntaxError:
                    # Take odd strings literally.
                    word = literal_eval(f"'{key}'")
                smooth_count = counts.smoother(freq)
                probs[word] = math.log(smooth_count) - log_total
    oov_prob = math.log(counts.smoother(0)) - log_total
    return probs, oov_prob


def read_clusters(clusters_loc: Path) -> dict:
    """Read a clusters file into a dict mapping words to cluster strings.

    Each line is expected to hold three whitespace-separated fields:
    cluster, word and frequency. Lines with a different number of fields
    are skipped. Lower-, title- and upper-cased variants of each word are
    added with the same cluster unless already present.

    clusters_loc: Location of the clusters file (read as utf8).
    RETURNS: Dict mapping word to its cluster string.
    """
    clusters = {}
    if ftfy is None:
        warnings.warn(Warnings.W004)
    with clusters_loc.open(encoding="utf8") as f:
        for line in tqdm(f):
            try:
                cluster, word, freq = line.split()
                if ftfy is not None:
                    word = ftfy.fix_text(word)
            except ValueError:
                continue
            # If the clusterer has only seen the word a few times, its
            # cluster is unreliable.
            if int(freq) >= 3:
                clusters[word] = cluster
            else:
                clusters[word] = "0"
    # Expand clusters with re-casing
    for word, cluster in list(clusters.items()):
        if word.lower() not in clusters:
            clusters[word.lower()] = cluster
        if word.title() not in clusters:
            clusters[word.title()] = cluster
        if word.upper() not in clusters:
            clusters[word.upper()] = cluster
    return clusters