mirror of
synced 2024-11-11 04:08:09 +03:00
Remove hacks and wrappers, keep code in sync across our libraries and move spaCy a few steps closer to only depending on packages with binary wheels 🎉 See here: https://github.com/explosion/srsly Serialization is hard, especially across Python versions and multiple platforms. After dealing with many subtle bugs over the years (encodings, locales, large files), our libraries like spaCy and Prodigy have steadily grown a number of utility functions to wrap the multiple serialization formats we need to support (especially json, msgpack and pickle). These wrapping functions ended up duplicated across our codebases, so we wanted to put them in one place. At the same time, we noticed that having a lot of small dependencies was making maintenance harder, and making installation slower. To solve this, we've made srsly standalone, by including the component packages directly within it. This way we can provide all the serialization utilities we need in a single binary wheel. srsly currently includes forks of the following packages: ujson, msgpack, msgpack-numpy, cloudpickle. * WIP: replace json/ujson with srsly * Replace ujson in examples. Use regular json instead of srsly to make code easier to read and follow * Update requirements * Fix imports * Fix typos * Replace msgpack with srsly * Fix warning
236 lines
8.1 KiB
236 lines
8.1 KiB
# coding: utf8
from __future__ import unicode_literals
import plac
import math
from tqdm import tqdm
import numpy
from ast import literal_eval
from pathlib import Path
from preshed.counter import PreshCounter
import tarfile
import gzip
import zipfile
import srsly
from wasabi import Printer
from ._messages import Messages
from ..vectors import Vectors
from ..errors import Errors, Warnings, user_warning
from ..util import ensure_path, get_lang_class
# ftfy is an optional dependency: when it is not installed, ftfy is set to
# None and callers must check for that before using it.
try:
    import ftfy
except ImportError:
    ftfy = None
# Module-level wasabi Printer used throughout this file for CLI status output
# (msg.loading, msg.good, msg.warn and msg.fail).
msg = Printer()
@plac.annotations(
    lang=("Model language", "positional", None, str),
    output_dir=("Model output directory", "positional", None, Path),
    freqs_loc=("Location of words frequencies file", "option", "f", Path),
    jsonl_loc=("Location of JSONL-formatted attributes file", "option", "j", Path),
    clusters_loc=("Optional location of brown clusters data", "option", "c", str),
    # The comma after the help text matters: without it the help string and
    # "option" are concatenated and plac mis-reads kind/abbrev/type.
    vectors_loc=("Optional vectors file in Word2Vec format", "option", "v", str),
    prune_vectors=("Optional number of vectors to prune to", "option", "V", int),
)
def init_model(
    lang,
    output_dir,
    freqs_loc=None,
    clusters_loc=None,
    jsonl_loc=None,
    vectors_loc=None,
    prune_vectors=-1,
):
    """
    Create a new model from raw data, like word frequencies, Brown clusters
    and word vectors. If vectors are provided in Word2Vec format, they can
    be either a .txt or zipped as a .zip or .tar.gz.

    lang (unicode): Language code, resolved with `get_lang_class`.
    output_dir (unicode / Path): Directory the model is saved to. Created
        (including parents) if it does not exist.
    freqs_loc (Path): Deprecated word frequencies file. Not used when
        `jsonl_loc` is given.
    clusters_loc (unicode / Path): Deprecated Brown clusters file. Not used
        when `jsonl_loc` is given.
    jsonl_loc (Path): JSONL file with one dict of lexical attributes per line.
    vectors_loc (unicode / Path): Optional word vectors file.
    prune_vectors (int): If >= 1, prune the vectors table to this many rows.
    RETURNS (Language): The created nlp object.
    """
    if jsonl_loc is not None:
        if freqs_loc is not None or clusters_loc is not None:
            # -f / -c cannot be combined with -j; only the JSONL data is used.
            msg.warn(Messages.M063, Messages.M064)
        jsonl_loc = ensure_path(jsonl_loc)
        lex_attrs = srsly.read_jsonl(jsonl_loc)
    else:
        clusters_loc = ensure_path(clusters_loc)
        freqs_loc = ensure_path(freqs_loc)
        if freqs_loc is not None and not freqs_loc.exists():
            msg.fail(Messages.M037, freqs_loc, exits=1)
        lex_attrs = read_attrs_from_deprecated(freqs_loc, clusters_loc)

    with msg.loading("Creating model..."):
        nlp = create_model(lang, lex_attrs)
    msg.good("Successfully created model")
    if vectors_loc is not None:
        add_vectors(nlp, vectors_loc, prune_vectors)
    vec_added = len(nlp.vocab.vectors)
    lex_added = len(nlp.vocab)
    msg.good(Messages.M038, Messages.M039.format(entries=lex_added, vectors=vec_added))
    # Accept plain strings when called from Python rather than via plac.
    output_dir = ensure_path(output_dir)
    if not output_dir.exists():
        output_dir.mkdir(parents=True)
    nlp.to_disk(output_dir)
    return nlp
def open_file(loc):
    """Open a possibly-compressed text file and return an iterable of lines.

    Supported inputs:
    - tar archives, compressed or not: the first regular-file member is read;
    - files whose name ends in "gz";
    - files whose name ends in "zip": the first archive member is read;
    - anything else is opened as a plain text file.
    All contents are decoded as UTF-8.

    loc (unicode / Path): Location of the file.
    RETURNS: An iterable of unicode lines that also supports `next()`.
    RAISES ValueError: If a tar or zip archive contains no file to read.
    """
    loc = ensure_path(loc)
    if tarfile.is_tarfile(str(loc)):
        # "r:*" lets tarfile detect the compression, so plain .tar archives
        # (which is_tarfile also accepts) open as well as .tar.gz ones.
        tar_file = tarfile.open(str(loc), "r:*")
        # Iterating a TarFile yields TarInfo headers, not text, so pick the
        # first regular file and read its contents instead. Iteration is lazy,
        # so we stop at the first match without indexing the whole archive.
        member = next((info for info in tar_file if info.isfile()), None)
        if member is None:
            raise ValueError("No file found in tar archive: {}".format(loc))
        file_ = tar_file.extractfile(member)
        return (line.decode("utf8") for line in file_)
    elif loc.parts[-1].endswith("gz"):
        return (line.decode("utf8") for line in gzip.open(str(loc), "r"))
    elif loc.parts[-1].endswith("zip"):
        zip_file = zipfile.ZipFile(str(loc))
        names = zip_file.namelist()
        if not names:
            raise ValueError("No file found in zip archive: {}".format(loc))
        file_ = zip_file.open(names[0])
        return (line.decode("utf8") for line in file_)
    else:
        return loc.open("r", encoding="utf8")
def read_attrs_from_deprecated(freqs_loc, clusters_loc):
    """Build lexeme attribute dicts from the legacy -f / -c input files.

    freqs_loc (Path or None): Word frequencies file, read with `read_freqs`.
        If None, no words are produced.
    clusters_loc (Path or None): Brown clusters file, read with
        `read_clusters`. If falsy, every word gets cluster 0.
    RETURNS (list): One dict per word with keys "orth", "id", "prob" and
        "cluster". Words are ordered by descending probability, and "id" is
        each word's position in that order.
    """
    with msg.loading("Counting frequencies..."):
        # The OOV probability returned by read_freqs is not needed here.
        probs = read_freqs(freqs_loc)[0] if freqs_loc is not None else {}
    msg.good("Counted frequencies")
    with msg.loading("Reading clusters..."):
        clusters = read_clusters(clusters_loc) if clusters_loc else {}
    msg.good("Read clusters")
    lex_attrs = []
    sorted_probs = sorted(probs.items(), key=lambda item: item[1], reverse=True)
    for i, (word, prob) in tqdm(enumerate(sorted_probs), total=len(sorted_probs)):
        attrs = {"orth": word, "id": i, "prob": prob, "cluster": 0}
        # Decode as a little-endian string, so that we can do & 15 to get
        # the first 4 bits. See _parse_features.pyx
        if word in clusters:
            attrs["cluster"] = int(clusters[word][::-1], 2)
        lex_attrs.append(attrs)
    return lex_attrs
def create_model(lang, lex_attrs):
    """Instantiate a blank pipeline for `lang` and load lexical attributes.

    lang (unicode): Language code, resolved with `get_lang_class`.
    lex_attrs (iterable): Dicts of lexeme attributes, each with at least an
        "orth" key. Entries containing a "settings" key are skipped.
    RETURNS (Language): The new nlp object. `nlp.vocab.cfg["oov_prob"]` is set
        to one less than the lowest lexeme probability in the vocab.
    """
    lang_class = get_lang_class(lang)
    nlp = lang_class()
    # Reset the rank of every lexeme the language class pre-populated.
    # add_vectors treats a non-zero rank as a vector row index.
    # NOTE(review): assumes the "id" attribute set below maps to Lexeme.rank
    # -- confirm.
    for lexeme in nlp.vocab:
        lexeme.rank = 0
    for attrs in lex_attrs:
        if "settings" in attrs:
            # Not a lexeme entry: nothing to add to the vocab.
            continue
        lexeme = nlp.vocab[attrs["orth"]]
        lexeme.set_attrs(**attrs)
        lexeme.is_oov = False
    oov_prob = min(lex.prob for lex in nlp.vocab)
    nlp.vocab.cfg.update({"oov_prob": oov_prob - 1})
    return nlp
def add_vectors(nlp, vectors_loc, prune_vectors):
    """Load word vectors into `nlp.vocab` and name the vectors table.

    nlp (Language): Pipeline whose vocab receives the vectors.
    vectors_loc (unicode / Path): Either a .npz file loaded with numpy, or a
        text vectors file read with `read_vectors`. If falsy, no new vectors
        are loaded but the existing table is still named.
    prune_vectors (int): If >= 1, prune the table to this many rows with
        `Vocab.prune_vectors`. None or values below 1 disable pruning.
    """
    vectors_loc = ensure_path(vectors_loc)
    if vectors_loc and vectors_loc.parts[-1].endswith(".npz"):
        # Load from the path instead of an opened handle that would never be
        # closed; numpy then owns whatever file it needs to keep open.
        nlp.vocab.vectors = Vectors(data=numpy.load(str(vectors_loc)))
        # Attach each ranked lexeme to the table row matching its rank.
        for lex in nlp.vocab:
            if lex.rank:
                nlp.vocab.vectors.add(lex.orth, row=lex.rank)
    else:
        if vectors_loc:
            with msg.loading("Reading vectors from {}".format(vectors_loc)):
                vectors_data, vector_keys = read_vectors(vectors_loc)
            msg.good("Loaded vectors from {}".format(vectors_loc))
        else:
            vectors_data, vector_keys = (None, None)
        if vector_keys is not None:
            # Make sure every word that has a vector exists in the vocab.
            for word in vector_keys:
                if word not in nlp.vocab:
                    lexeme = nlp.vocab[word]
                    lexeme.is_oov = False
        if vectors_data is not None:
            nlp.vocab.vectors = Vectors(data=vectors_data, keys=vector_keys)
    nlp.vocab.vectors.name = "%s_model.vectors" % nlp.meta["lang"]
    nlp.meta["vectors"]["name"] = nlp.vocab.vectors.name
    if prune_vectors is not None and prune_vectors >= 1:
        nlp.vocab.prune_vectors(prune_vectors)
def read_vectors(vectors_loc):
    """Read a Word2Vec-style text vectors file.

    The first line must hold the table shape as whitespace-separated integers
    ("<rows> <width>"). Every following line is a word followed by `width`
    space-separated floats. A malformed line ends the process via
    `msg.fail(..., exits=1)`.

    vectors_loc (unicode / Path): Location of the file, opened with
        `open_file`.
    RETURNS (tuple): (vectors_data, vectors_keys). vectors_data is a float32
        numpy array with the header's shape. vectors_keys is the list of words,
        where vectors_keys[i] labels row i.
    """
    f = open_file(vectors_loc)
    shape = tuple(int(size) for size in next(f).split())
    vectors_data = numpy.zeros(shape=shape, dtype="f")
    n_rows, width = vectors_data.shape
    vectors_keys = []
    for i, line in enumerate(tqdm(f)):
        line = line.rstrip()
        # Splitting at most width + 1 times (rather than width) means a line
        # with an extra field leaves too many pieces and is rejected below,
        # instead of the extra field being folded into the word.
        pieces = line.rsplit(" ", width + 1)
        word = pieces.pop(0)
        # Also reject rows beyond the count declared in the header, which
        # would otherwise raise an IndexError on assignment.
        if i >= n_rows or len(pieces) != width:
            # Report the 1-based physical line number (line 1 is the header).
            msg.fail(Errors.E094.format(line_num=i + 2, loc=vectors_loc), exits=1)
        vectors_data[i] = numpy.asarray(pieces, dtype="f")
        vectors_keys.append(word)
    return vectors_data, vectors_keys
def read_freqs(freqs_loc, max_length=100, min_doc_freq=5, min_freq=50):
    """Read a tab-separated word frequencies file into smoothed log probs.

    Each line has three tab-separated fields: total frequency, document
    frequency, and a key that is a Python string literal (parsed with
    `ast.literal_eval`). The file is read twice. The first pass fits the
    frequency smoother; the second computes per-word probabilities.

    freqs_loc (Path): Location of the frequencies file (UTF-8).
    max_length (int): Skip keys whose literal is `max_length` chars or longer.
    min_doc_freq (int): Skip words with a document frequency below this.
    min_freq (int): Skip words with a total frequency below this.
    RETURNS (tuple): (probs, oov_prob). probs maps each word to its log
        probability; oov_prob is the log probability of an unseen word.
    RAISES ValueError: If the file contains no frequency counts.
    """
    counts = PreshCounter()
    total = 0
    with freqs_loc.open("r", encoding="utf8") as f:
        for i, line in enumerate(f):
            freq, doc_freq, key = line.rstrip().split("\t", 2)
            freq = int(freq)
            # Line numbers serve only as distinct keys. Offset by 1 --
            # presumably because 0 is reserved in preshed tables; confirm.
            counts.inc(i + 1, freq)
            total += freq
    if total == 0:
        # Checked up front: math.log(0) below would fail with a
        # less helpful error.
        raise ValueError("No word frequencies found in {}".format(freqs_loc))
    # Fit the smoother on the collected counts; counts.smoother is used below.
    counts.smooth()
    log_total = math.log(total)
    probs = {}
    with freqs_loc.open("r", encoding="utf8") as f:
        for line in tqdm(f):
            freq, doc_freq, key = line.rstrip().split("\t", 2)
            doc_freq = int(doc_freq)
            freq = int(freq)
            if doc_freq >= min_doc_freq and freq >= min_freq and len(key) < max_length:
                word = literal_eval(key)
                smooth_count = counts.smoother(freq)
                probs[word] = math.log(smooth_count) - log_total
    oov_prob = math.log(counts.smoother(0)) - log_total
    return probs, oov_prob
def read_clusters(clusters_loc):
    """Read a Brown clusters file into a dict mapping words to bit strings.

    Each line should hold three whitespace-separated fields: the cluster bit
    string, the word, and its frequency. Lines that don't split into exactly
    three fields, or whose frequency isn't an integer, are skipped. Words
    seen fewer than 3 times get cluster "0". Lower-, title- and upper-cased
    variants of each word are added when not already present. If ftfy is
    installed, words are normalised with `ftfy.fix_text`; otherwise a
    warning is emitted.

    clusters_loc (Path): Location of the clusters file (UTF-8).
    RETURNS (dict): Word to cluster bit string.
    """
    clusters = {}
    if ftfy is None:
        user_warning(Warnings.W004)
    with clusters_loc.open("r", encoding="utf8") as f:
        for line in tqdm(f):
            try:
                cluster, word, freq = line.split()
                freq = int(freq)
                if ftfy is not None:
                    word = ftfy.fix_text(word)
            except ValueError:
                # Malformed line: wrong number of fields or bad frequency.
                continue
            # If the clusterer has only seen the word a few times, its
            # cluster is unreliable.
            clusters[word] = cluster if freq >= 3 else "0"
    # Expand clusters with re-casing. Iterate over a snapshot because the
    # dict grows inside the loop; existing entries are never overwritten.
    for word, cluster in list(clusters.items()):
        for variant in (word.lower(), word.title(), word.upper()):
            if variant not in clusters:
                clusters[variant] = cluster
    return clusters