mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-31 16:07:41 +03:00 
			
		
		
		
	* Add init CLI and init config draft * Improve config validation * Auto-format * Don't export anything in debug config * Update docs
		
			
				
	
	
		
			331 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			331 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| from typing import Optional, List, Dict, Any, Union, IO
 | |
| import math
 | |
| from tqdm import tqdm
 | |
| import numpy
 | |
| from ast import literal_eval
 | |
| from pathlib import Path
 | |
| from preshed.counter import PreshCounter
 | |
| import tarfile
 | |
| import gzip
 | |
| import zipfile
 | |
| import srsly
 | |
| import warnings
 | |
| from wasabi import msg, Printer
 | |
| import typer
 | |
| 
 | |
| from ._util import app, init_cli, Arg, Opt
 | |
| from ..vectors import Vectors
 | |
| from ..errors import Errors, Warnings
 | |
| from ..language import Language
 | |
| from ..util import ensure_path, get_lang_class, load_model, OOV_RANK
 | |
| 
 | |
| try:
 | |
|     import ftfy
 | |
| except ImportError:
 | |
|     ftfy = None
 | |
| 
 | |
| 
 | |
| DEFAULT_OOV_PROB = -20
 | |
| 
 | |
| 
 | |
@init_cli.command("model")
@app.command(
    "init-model",
    context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
    hidden=True,  # hide this from main CLI help but still allow it to work with warning
)
def init_model_cli(
    # fmt: off
    ctx: typer.Context,  # This is only used to read additional arguments
    lang: str = Arg(..., help="Model language"),
    output_dir: Path = Arg(..., help="Model output directory"),
    freqs_loc: Optional[Path] = Arg(None, help="Location of words frequencies file", exists=True),
    clusters_loc: Optional[Path] = Opt(None, "--clusters-loc", "-c", help="Optional location of brown clusters data", exists=True),
    jsonl_loc: Optional[Path] = Opt(None, "--jsonl-loc", "-j", help="Location of JSONL-formatted attributes file", exists=True),
    vectors_loc: Optional[Path] = Opt(None, "--vectors-loc", "-v", help="Optional vectors file in Word2Vec format", exists=True),
    prune_vectors: int = Opt(-1, "--prune-vectors", "-V", help="Optional number of vectors to prune to"),
    truncate_vectors: int = Opt(0, "--truncate-vectors", "-t", help="Optional number of vectors to truncate to when reading in vectors file"),
    vectors_name: Optional[str] = Opt(None, "--vectors-name", "-vn", help="Optional name for the word vectors, e.g. en_core_web_lg.vectors"),
    model_name: Optional[str] = Opt(None, "--model-name", "-mn", help="Optional name for the model meta"),
    base_model: Optional[str] = Opt(None, "--base-model", "-b", help="Base model (for languages with custom tokenizers)")
    # fmt: on
):
    """
    Create a new model from raw data. If vectors are provided in Word2Vec format,
    they can be either a .txt or zipped as a .zip or .tar.gz.
    """
    # This function is registered twice: as the "model" subcommand of the init
    # CLI and as the hidden top-level "init-model" command. Only the hyphenated
    # spelling gets the migration notice.
    called_with_legacy_name = ctx.command.name == "init-model"
    if called_with_legacy_name:
        migration_notice = (
            "The init-model command is now available via the 'init model' "
            "subcommand (without the hyphen). You can run python -m spacy init "
            "--help for an overview of the other available initialization commands."
        )
        msg.warn(migration_notice)
    # Forward every CLI option unchanged; the CLI always prints progress.
    forwarded_options = dict(
        freqs_loc=freqs_loc,
        clusters_loc=clusters_loc,
        jsonl_loc=jsonl_loc,
        vectors_loc=vectors_loc,
        prune_vectors=prune_vectors,
        truncate_vectors=truncate_vectors,
        vectors_name=vectors_name,
        model_name=model_name,
        base_model=base_model,
        silent=False,
    )
    init_model(lang, output_dir, **forwarded_options)
 | |
| 
 | |
| 
 | |
def init_model(
    lang: str,
    output_dir: Path,
    freqs_loc: Optional[Path] = None,
    clusters_loc: Optional[Path] = None,
    jsonl_loc: Optional[Path] = None,
    vectors_loc: Optional[Path] = None,
    prune_vectors: int = -1,
    truncate_vectors: int = 0,
    vectors_name: Optional[str] = None,
    model_name: Optional[str] = None,
    base_model: Optional[str] = None,
    silent: bool = True,
) -> Language:
    """Create a model from lexical data, optionally add vectors, and save it.

    lang (str): Language code, forwarded to create_model.
    output_dir (Path): Directory the model is written to. It is created,
        including missing parent directories, if it does not exist.
    freqs_loc (Optional[Path]): Deprecated word frequencies file. Only read
        when jsonl_loc is None.
    clusters_loc (Optional[Path]): Deprecated clusters file. Only read when
        jsonl_loc is None.
    jsonl_loc (Optional[Path]): JSONL file of lexical attributes, read with
        srsly.read_jsonl. Takes precedence over freqs_loc / clusters_loc.
    vectors_loc (Optional[Path]): Vectors file, forwarded to add_vectors.
    prune_vectors (int): Forwarded to add_vectors.
    truncate_vectors (int): Forwarded to add_vectors.
    vectors_name (Optional[str]): Forwarded to add_vectors.
    model_name (Optional[str]): Forwarded to create_model as the model name.
    base_model (Optional[str]): Forwarded to create_model.
    silent (bool): If True, the local printer produces no output.
    RETURNS (Language): The nlp object that was written to output_dir.
    """
    # Local printer shadows the module-level one so that `silent` is honoured.
    msg = Printer(no_print=silent, pretty=not silent)
    if jsonl_loc is not None:
        if freqs_loc is not None or clusters_loc is not None:
            # Only a warning: processing continues using the JSONL file alone.
            msg.warn(
                "Incompatible arguments",
                "The -f and -c arguments are deprecated, and not compatible "
                "with the -j argument, which should specify the same "
                "information. Either merge the frequencies and clusters data "
                "into the JSONL-formatted file (recommended), or use only the "
                "-f and -c files, without the other lexical attributes.",
            )
        jsonl_loc = ensure_path(jsonl_loc)
        lex_attrs = srsly.read_jsonl(jsonl_loc)
    else:
        clusters_loc = ensure_path(clusters_loc)
        freqs_loc = ensure_path(freqs_loc)
        if freqs_loc is not None and not freqs_loc.exists():
            msg.fail("Can't find words frequencies file", freqs_loc, exits=1)
        lex_attrs = read_attrs_from_deprecated(msg, freqs_loc, clusters_loc)

    with msg.loading("Creating model..."):
        nlp = create_model(lang, lex_attrs, name=model_name, base_model=base_model)

    msg.good("Successfully created model")
    if vectors_loc is not None:
        add_vectors(
            msg, nlp, vectors_loc, truncate_vectors, prune_vectors, vectors_name
        )
    vec_added = len(nlp.vocab.vectors)
    lex_added = len(nlp.vocab)
    msg.good(
        "Successfully compiled vocab", f"{lex_added} entries, {vec_added} vectors",
    )
    # Accept str paths from programmatic callers and create nested output
    # directories instead of failing when a parent directory is missing.
    output_dir = ensure_path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    nlp.to_disk(output_dir)
    return nlp
 | |
| 
 | |
| 
 | |
def open_file(loc: Union[str, Path]) -> IO:
    """Open a text file for line-by-line reading, unpacking archives if needed.

    Handles tar archives (any compression tarfile can detect, e.g. .tar.gz),
    .gz files, .zip archives and plain text files. For archives, the first
    regular file in the archive is read. Compressed content is decoded as
    UTF-8.

    loc (Union[str, Path]): Path to the file.
    RETURNS (IO): An iterator over text lines. For plain files this is the open
        file object; for compressed inputs it is a generator that closes the
        underlying handles once it is exhausted or closed.
    RAISES (ValueError): If an archive contains no regular file.
    """
    loc = ensure_path(loc)
    if tarfile.is_tarfile(str(loc)):
        # "r:*" lets tarfile detect the compression, so every archive that
        # is_tarfile() accepts can be opened (not only gzip-compressed ones).
        archive = tarfile.open(str(loc), "r:*")
        try:
            member = next((info for info in archive if info.isfile()), None)
            if member is None:
                raise ValueError(f"No regular file found in tar archive: {loc}")
            stream = archive.extractfile(member)
        except Exception:
            archive.close()
            raise
        return _iter_utf8_lines(stream, archive)
    elif loc.parts[-1].endswith("gz"):
        return _iter_utf8_lines(gzip.open(str(loc), "r"))
    elif loc.parts[-1].endswith("zip"):
        zip_file = zipfile.ZipFile(str(loc))
        try:
            # Directory entries end with "/" and hold no data; skip them.
            names = [name for name in zip_file.namelist() if not name.endswith("/")]
            if not names:
                raise ValueError(f"No file found in zip archive: {loc}")
            file_ = zip_file.open(names[0])
        except Exception:
            zip_file.close()
            raise
        return _iter_utf8_lines(file_, zip_file)
    else:
        return loc.open("r", encoding="utf8")


def _iter_utf8_lines(stream, *owners):
    """Yield UTF-8 decoded lines from a binary stream, then close everything."""
    try:
        for raw_line in stream:
            yield raw_line.decode("utf8")
    finally:
        stream.close()
        for owner in owners:
            owner.close()
 | |
| 
 | |
| 
 | |
def read_attrs_from_deprecated(
    msg: Printer, freqs_loc: Optional[Path], clusters_loc: Optional[Path]
) -> List[Dict[str, Any]]:
    """Build lexical attribute dicts from the deprecated frequency/cluster files.

    msg (Printer): Printer used for progress messages.
    freqs_loc (Optional[Path]): Frequencies file passed to read_freqs. If None,
        no entries are produced.
    clusters_loc (Optional[Path]): Clusters file passed to read_clusters. If
        falsy, every entry gets cluster 0.
    RETURNS (List[Dict[str, Any]]): One dict per word with the keys "orth",
        "id", "prob" and "cluster", ordered by descending probability.
    """
    if freqs_loc is None:
        probs, _ = ({}, DEFAULT_OOV_PROB)  # noqa: F841
    else:
        with msg.loading("Counting frequencies..."):
            probs, _ = read_freqs(freqs_loc)
        msg.good("Counted frequencies")

    clusters = {}
    if clusters_loc:
        with msg.loading("Reading clusters..."):
            clusters = read_clusters(clusters_loc)
        msg.good("Read clusters")

    # Most probable words first; the position in this ranking becomes the "id".
    ranked = sorted(probs.items(), key=lambda pair: pair[1], reverse=True)
    entries = []
    if not ranked:
        return entries
    for rank, (orth, log_prob) in tqdm(enumerate(ranked)):
        bit_string = clusters.get(orth)
        # Decode as a little-endian string, so that we can do & 15 to get
        # the first 4 bits. See _parse_features.pyx
        cluster_id = 0 if bit_string is None else int(bit_string[::-1], 2)
        entries.append(
            {"orth": orth, "id": rank, "prob": log_prob, "cluster": cluster_id}
        )
    return entries
 | |
| 
 | |
| 
 | |
def create_model(
    lang: str,
    lex_attrs: List[Dict[str, Any]],
    name: Optional[str] = None,
    base_model: Optional[Union[str, Path]] = None,
) -> Language:
    """Create an nlp object and populate its vocab from lexical attributes.

    lang (str): Language code used to look up the language class when no
        base model is given.
    lex_attrs (List[Dict[str, Any]]): Attribute dicts; each needs an "orth"
        key. Dicts containing a "settings" key are skipped.
    name (Optional[str]): If set, stored as the "name" entry of the model meta.
    base_model (Optional[Union[str, Path]]): If set, loaded with load_model
        and stripped of its pipeline components.
    RETURNS (Language): The nlp object with the updated vocab.
    """
    if not base_model:
        nlp = get_lang_class(lang)()
    else:
        nlp = load_model(base_model)
        # keep the tokenizer but remove any existing pipeline components due to
        # potentially conflicting vectors
        for pipe_name in nlp.pipe_names:
            nlp.remove_pipe(pipe_name)

    # Reset the rank of every existing lexeme before applying the new attrs.
    for existing in nlp.vocab:
        existing.rank = OOV_RANK

    for entry in lex_attrs:
        if "settings" not in entry:
            nlp.vocab[entry["orth"]].set_attrs(**entry)

    # The OOV probability is set just below the lowest probability in the vocab.
    has_lexemes = len(nlp.vocab)
    oov_prob = (
        min(lex.prob for lex in nlp.vocab) - 1 if has_lexemes else DEFAULT_OOV_PROB
    )
    nlp.vocab.cfg.update({"oov_prob": oov_prob})

    if name:
        nlp.meta["name"] = name
    return nlp
 | |
| 
 | |
| 
 | |
def add_vectors(
    msg: Printer,
    nlp: Language,
    vectors_loc: Optional[Path],
    truncate_vectors: int,
    prune_vectors: int,
    name: Optional[str] = None,
) -> None:
    """Load vectors from a file and attach them to nlp.vocab.

    msg (Printer): Printer used for progress messages.
    nlp (Language): The nlp object whose vocab receives the vectors.
    vectors_loc (Optional[Path]): Vectors file. A name ending in ".npz" is
        loaded with numpy.load; anything else is parsed by read_vectors.
    truncate_vectors (int): Forwarded to read_vectors.
    prune_vectors (int): If >= 1, passed to nlp.vocab.prune_vectors after the
        vectors have been added.
    name (Optional[str]): Name for the vectors. Defaults to
        "{lang}_model.vectors" using the language from nlp.meta.
    """
    vectors_loc = ensure_path(vectors_loc)
    if vectors_loc and vectors_loc.parts[-1].endswith(".npz"):
        # Build the Vectors object while the file is still open, then release
        # the handle instead of leaking it.
        with vectors_loc.open("rb") as file_:
            nlp.vocab.vectors = Vectors(data=numpy.load(file_))
        # Map each ranked lexeme onto the vector row given by its rank.
        for lex in nlp.vocab:
            if lex.rank and lex.rank != OOV_RANK:
                nlp.vocab.vectors.add(lex.orth, row=lex.rank)
    else:
        if vectors_loc:
            with msg.loading(f"Reading vectors from {vectors_loc}"):
                vectors_data, vector_keys = read_vectors(
                    msg, vectors_loc, truncate_vectors
                )
            msg.good(f"Loaded vectors from {vectors_loc}")
        else:
            vectors_data, vector_keys = (None, None)
        if vector_keys is not None:
            for word in vector_keys:
                if word not in nlp.vocab:
                    # Looking the word up creates its vocab entry.
                    nlp.vocab[word]
        if vectors_data is not None:
            nlp.vocab.vectors = Vectors(data=vectors_data, keys=vector_keys)
    if name is None:
        nlp.vocab.vectors.name = f"{nlp.meta['lang']}_model.vectors"
    else:
        nlp.vocab.vectors.name = name
    nlp.meta["vectors"]["name"] = nlp.vocab.vectors.name
    if prune_vectors >= 1:
        nlp.vocab.prune_vectors(prune_vectors)
 | |
| 
 | |
| 
 | |
def read_vectors(msg: Printer, vectors_loc: Path, truncate_vectors: int):
    """Read word vectors from a text file into an array and a list of keys.

    The first line must hold two integers: the number of rows and the vector
    width. Every following line holds a word followed by that many
    space-separated values.

    msg (Printer): Printer used to report malformed lines; a malformed line
        ends the process with exit code 1.
    vectors_loc (Path): File to read, opened via open_file.
    truncate_vectors (int): If >= 1, read at most this many vectors.
    RETURNS: A tuple of the float32 vectors array and the list of words, in
        file order.
    """
    f = open_file(vectors_loc)
    try:
        shape = tuple(int(size) for size in next(f).split())
        if truncate_vectors >= 1:
            # Never allocate more rows than the file declares, otherwise the
            # table ends up padded with all-zero rows that have no key.
            shape = (min(truncate_vectors, shape[0]), shape[1])
        vectors_data = numpy.zeros(shape=shape, dtype="f")
        vectors_keys = []
        for i, line in enumerate(tqdm(f)):
            line = line.rstrip()
            # Split from the right so that words containing spaces stay intact.
            pieces = line.rsplit(" ", vectors_data.shape[1])
            word = pieces.pop(0)
            if len(pieces) != vectors_data.shape[1]:
                msg.fail(Errors.E094.format(line_num=i, loc=vectors_loc), exits=1)
            vectors_data[i] = numpy.asarray(pieces, dtype="f")
            vectors_keys.append(word)
            if i == truncate_vectors - 1:
                break
    finally:
        # Release the stream even when stopping early or exiting on an error.
        f.close()
    return vectors_data, vectors_keys
 | |
| 
 | |
| 
 | |
def read_freqs(
    freqs_loc: Path, max_length: int = 100, min_doc_freq: int = 5, min_freq: int = 50
):
    """Read a tab-separated frequencies file and compute log probabilities.

    Each line holds three tab-separated fields: frequency, document frequency
    and the word key. The file is read twice: once to collect the counts used
    for smoothing and once to compute the per-word log probabilities.

    freqs_loc (Path): The frequencies file, read as UTF-8.
    max_length (int): Keys of this length or longer are dropped.
    min_doc_freq (int): Minimum document frequency for a word to be kept.
    min_freq (int): Minimum frequency for a word to be kept.
    RETURNS: A tuple (probs, oov_prob). probs maps each kept word to the log of
        its smoothed count minus the log of the total count; oov_prob is the
        same quantity for a count of 0. An empty file, or one whose
        frequencies sum to 0, yields ({}, DEFAULT_OOV_PROB).
    """
    counts = PreshCounter()
    total = 0
    with freqs_loc.open(encoding="utf8") as f:
        for i, line in enumerate(f):
            freq = int(line.rstrip().split("\t", 2)[0])
            counts.inc(i + 1, freq)
            total += freq
    if total <= 0:
        # Nothing to estimate from; math.log(0) would raise.
        return {}, DEFAULT_OOV_PROB
    counts.smooth()
    log_total = math.log(total)
    probs = {}
    with freqs_loc.open(encoding="utf8") as f:
        for line in tqdm(f):
            freq, doc_freq, key = line.rstrip().split("\t", 2)
            doc_freq = int(doc_freq)
            freq = int(freq)
            if doc_freq >= min_doc_freq and freq >= min_freq and len(key) < max_length:
                word = _parse_freq_key(key)
                smooth_count = counts.smoother(freq)
                probs[word] = math.log(smooth_count) - log_total
    oov_prob = math.log(counts.smoother(0)) - log_total
    return probs, oov_prob


def _parse_freq_key(key: str) -> str:
    """Decode a key from the frequencies file into the word it stands for."""
    for candidate in (key, f"'{key}'"):
        try:
            word = literal_eval(candidate)
        except (SyntaxError, ValueError):
            # Not a valid literal in this form; try the next interpretation.
            continue
        if isinstance(word, str):
            return word
    # Take odd strings (and non-string literals such as numbers) literally.
    return key
 | |
| 
 | |
| 
 | |
def read_clusters(clusters_loc: Path) -> dict:
    """Read a clusters file into a word -> cluster string mapping.

    Each usable line holds three whitespace-separated fields: cluster, word
    and frequency. Lines that do not have exactly three fields, or whose
    frequency is not an integer, are skipped. If ftfy is installed, each word
    is passed through ftfy.fix_text; otherwise a warning is emitted once.

    clusters_loc (Path): The clusters file, read as UTF-8.
    RETURNS (dict): Maps each word, plus its lowercase, title-case and
        uppercase variants where not already present, to its cluster string.
        Words seen fewer than 3 times map to "0".
    """
    clusters = {}
    if ftfy is None:
        warnings.warn(Warnings.W004)
    with clusters_loc.open(encoding="utf8") as f:
        for line in tqdm(f):
            try:
                cluster, word, freq = line.split()
                freq = int(freq)
                if ftfy is not None:
                    word = ftfy.fix_text(word)
            except ValueError:
                # Malformed line: skip it rather than abort the whole read.
                continue
            # If the clusterer has only seen the word a few times, its
            # cluster is unreliable.
            clusters[word] = cluster if freq >= 3 else "0"
    # Expand clusters with re-casing
    for word, cluster in list(clusters.items()):
        for variant in (word.lower(), word.title(), word.upper()):
            if variant not in clusters:
                clusters[variant] = cluster
    return clusters
 |