mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-26 18:06:29 +03:00
f37863093a
Remove hacks and wrappers, keep code in sync across our libraries and move spaCy a few steps closer to only depending on packages with binary wheels 🎉 See here: https://github.com/explosion/srsly Serialization is hard, especially across Python versions and multiple platforms. After dealing with many subtle bugs over the years (encodings, locales, large files) our libraries like spaCy and Prodigy have steadily grown a number of utility functions to wrap the multiple serialization formats we need to support (especially json, msgpack and pickle). These wrapping functions ended up duplicated across our codebases, so we wanted to put them in one place. At the same time, we noticed that having a lot of small dependencies was making maintenance harder, and making installation slower. To solve this, we've made srsly standalone, by including the component packages directly within it. This way we can provide all the serialization utilities we need in a single binary wheel. srsly currently includes forks of the following packages: ujson msgpack msgpack-numpy cloudpickle * WIP: replace json/ujson with srsly * Replace ujson in examples Use regular json instead of srsly to make code easier to read and follow * Update requirements * Fix imports * Fix typos * Replace msgpack with srsly * Fix warning
236 lines
8.1 KiB
Python
236 lines
8.1 KiB
Python
# coding: utf8
|
|
from __future__ import unicode_literals
|
|
|
|
import plac
|
|
import math
|
|
from tqdm import tqdm
|
|
import numpy
|
|
from ast import literal_eval
|
|
from pathlib import Path
|
|
from preshed.counter import PreshCounter
|
|
import tarfile
|
|
import gzip
|
|
import zipfile
|
|
import srsly
|
|
from wasabi import Printer
|
|
|
|
from ._messages import Messages
|
|
from ..vectors import Vectors
|
|
from ..errors import Errors, Warnings, user_warning
|
|
from ..util import ensure_path, get_lang_class
|
|
|
|
try:
|
|
import ftfy
|
|
except ImportError:
|
|
ftfy = None
|
|
|
|
|
|
msg = Printer()
|
|
|
|
|
|
@plac.annotations(
    lang=("Model language", "positional", None, str),
    output_dir=("Model output directory", "positional", None, Path),
    freqs_loc=("Location of words frequencies file", "option", "f", Path),
    jsonl_loc=("Location of JSONL-formatted attributes file", "option", "j", Path),
    clusters_loc=("Optional location of brown clusters data", "option", "c", str),
    vectors_loc=("Optional vectors file in Word2Vec format", "option", "v", str),
    prune_vectors=("Optional number of vectors to prune to", "option", "V", int),
)
def init_model(
    lang,
    output_dir,
    freqs_loc=None,
    clusters_loc=None,
    jsonl_loc=None,
    vectors_loc=None,
    prune_vectors=-1,
):
    """
    Create a new model from raw data, like word frequencies, Brown clusters
    and word vectors. If vectors are provided in Word2Vec format, they can
    be either a .txt or zipped as a .zip or .tar.gz.
    """
    # NOTE: the docstring above is deliberately left as-is; plac uses the
    # function docstring as the command description.
    #
    # lang:          language code passed to get_lang_class().
    # output_dir:    Path the finished model is written to (created if missing).
    # freqs_loc:     deprecated frequencies file; only read when -j is absent.
    # clusters_loc:  deprecated clusters file; only read when -j is absent.
    # jsonl_loc:     JSONL file of lexical attributes; takes precedence.
    # vectors_loc:   optional vectors file handed to add_vectors().
    # prune_vectors: values >= 1 prune the vectors table to that many rows.
    # Returns the created nlp object after saving it to output_dir.
    if jsonl_loc is not None:
        if freqs_loc is not None or clusters_loc is not None:
            # The frequencies/clusters inputs are not read on this path, so
            # let the user know the combination of options is incompatible.
            msg.warn(Messages.M063, Messages.M064)
        jsonl_loc = ensure_path(jsonl_loc)
        lex_attrs = srsly.read_jsonl(jsonl_loc)
    else:
        clusters_loc = ensure_path(clusters_loc)
        freqs_loc = ensure_path(freqs_loc)
        if freqs_loc is not None and not freqs_loc.exists():
            msg.fail(Messages.M037, freqs_loc, exits=1)
        lex_attrs = read_attrs_from_deprecated(freqs_loc, clusters_loc)

    with msg.loading("Creating model..."):
        nlp = create_model(lang, lex_attrs)
    msg.good("Successfully created model")
    if vectors_loc is not None:
        add_vectors(nlp, vectors_loc, prune_vectors)
    vec_added = len(nlp.vocab.vectors)
    lex_added = len(nlp.vocab)
    msg.good(Messages.M038, Messages.M039.format(entries=lex_added, vectors=vec_added))
    if not output_dir.exists():
        output_dir.mkdir()
    nlp.to_disk(output_dir)
    return nlp
|
|
|
|
|
|
def open_file(loc):
    """Handle .gz, .tar.gz, .zip or unzipped files.

    Every branch returns an iterator over decoded (unicode) text lines, so
    callers can use next() and a for-loop regardless of the container format.
    For tar and zip archives only the first file in the archive is read.

    loc (unicode / Path): Location of the file to open.
    RETURNS (iterator): Iterator yielding the lines of the file as text.
    RAISES (ValueError): If a tar archive contains no regular file.
    """
    loc = ensure_path(loc)
    if tarfile.is_tarfile(str(loc)):
        # "r:*" lets tarfile detect the compression itself, so every archive
        # accepted by is_tarfile() (plain .tar, .tar.gz, ...) can be opened.
        archive = tarfile.open(str(loc), "r:*")
        for member in archive:
            if member.isfile():
                # A TarFile is not a line iterator; hand back the decoded
                # lines of the first regular member instead.
                extracted = archive.extractfile(member)
                return (line.decode("utf8") for line in extracted)
        archive.close()
        raise ValueError("No regular file found in tar archive: {}".format(loc))
    elif loc.parts[-1].endswith("gz"):
        return (line.decode("utf8") for line in gzip.open(str(loc), "r"))
    elif loc.parts[-1].endswith("zip"):
        zip_file = zipfile.ZipFile(str(loc))
        names = zip_file.namelist()
        file_ = zip_file.open(names[0])
        return (line.decode("utf8") for line in file_)
    else:
        return loc.open("r", encoding="utf8")
|
|
|
|
|
|
def read_attrs_from_deprecated(freqs_loc, clusters_loc):
    """Build lexical attribute dicts from the frequencies and clusters files.

    freqs_loc (Path): Frequencies file, or None to use no probabilities.
    clusters_loc (Path): Clusters file, or a falsy value to use no clusters.
    RETURNS (list): One dict per word with the keys "orth", "id", "prob" and
        "cluster", ordered by descending probability.
    """
    with msg.loading("Counting frequencies..."):
        if freqs_loc is None:
            probs, _oov_prob = {}, -20
        else:
            probs, _oov_prob = read_freqs(freqs_loc)
    msg.good("Counted frequencies")
    with msg.loading("Reading clusters..."):
        if clusters_loc:
            clusters = read_clusters(clusters_loc)
        else:
            clusters = {}
    msg.good("Read clusters")
    by_prob_desc = sorted(probs.items(), key=lambda pair: pair[1], reverse=True)
    lex_attrs = []
    for rank, (orth, prob) in tqdm(enumerate(by_prob_desc)):
        # Decode as a little-endian string, so that we can do & 15 to get
        # the first 4 bits. See _parse_features.pyx
        cluster = int(clusters[orth][::-1], 2) if orth in clusters else 0
        lex_attrs.append({"orth": orth, "id": rank, "prob": prob, "cluster": cluster})
    return lex_attrs
|
|
|
|
|
|
def create_model(lang, lex_attrs):
    """Create a blank language object and populate its vocab.

    lang (unicode): Language code used to look up the language class.
    lex_attrs (iterable): Dicts of lexical attributes. Each needs an "orth"
        key; entries containing a "settings" key are skipped.
    RETURNS: The new nlp object, with "oov_prob" set in its vocab config.
    """
    lang_class = get_lang_class(lang)
    nlp = lang_class()
    # Zero the rank of every lexeme already in the fresh vocab.
    for lexeme in nlp.vocab:
        lexeme.rank = 0
    for attrs in lex_attrs:
        if "settings" in attrs:
            continue
        lexeme = nlp.vocab[attrs["orth"]]
        lexeme.set_attrs(**attrs)
        lexeme.is_oov = False
    # OOV words get a probability just below the rarest known word. With an
    # empty vocab there is nothing to take the minimum of, so fall back to
    # -20, the same default OOV probability used when no frequencies file
    # is given.
    probs = [lex.prob for lex in nlp.vocab]
    oov_prob = min(probs) - 1 if probs else -20
    nlp.vocab.cfg.update({"oov_prob": oov_prob})
    return nlp
|
|
|
|
|
|
def add_vectors(nlp, vectors_loc, prune_vectors):
    """Attach word vectors to the vocab of nlp, optionally pruning them.

    nlp: The language object whose vocab receives the vectors.
    vectors_loc (unicode / Path): Vectors file. A name ending in ".npz" is
        loaded with numpy; anything else goes through read_vectors().
    prune_vectors (int): If >= 1, prune the vectors table to this many rows.
    """
    vectors_loc = ensure_path(vectors_loc)
    has_loc = bool(vectors_loc)
    if has_loc and vectors_loc.parts[-1].endswith(".npz"):
        loaded = numpy.load(vectors_loc.open("rb"))
        nlp.vocab.vectors = Vectors(data=loaded)
        # Map every lexeme with a non-zero rank to the row of that rank.
        for lex in nlp.vocab:
            if not lex.rank:
                continue
            nlp.vocab.vectors.add(lex.orth, row=lex.rank)
    else:
        vectors_data = None
        vector_keys = None
        if has_loc:
            with msg.loading("Reading vectors from {}".format(vectors_loc)):
                vectors_data, vector_keys = read_vectors(vectors_loc)
            msg.good("Loaded vectors from {}".format(vectors_loc))
        if vector_keys is not None:
            # Make sure every vector key has a lexeme that is not flagged OOV.
            for word in vector_keys:
                if word in nlp.vocab:
                    continue
                nlp.vocab[word].is_oov = False
        if vectors_data is not None:
            nlp.vocab.vectors = Vectors(data=vectors_data, keys=vector_keys)
    vectors_name = "%s_model.vectors" % nlp.meta["lang"]
    nlp.vocab.vectors.name = vectors_name
    nlp.meta["vectors"]["name"] = nlp.vocab.vectors.name
    if prune_vectors >= 1:
        nlp.vocab.prune_vectors(prune_vectors)
|
|
|
|
|
|
def read_vectors(vectors_loc):
    """Read a vectors file into a numpy array and a list of keys.

    The first line holds the table shape as whitespace-separated integers
    (rows, then dimensions). Every following line holds a key followed by
    one space-separated value per dimension.

    vectors_loc (unicode / Path): Location of the vectors file.
    RETURNS (tuple): (vectors_data, vectors_keys), where vectors_data is a
        float array of the declared shape and vectors_keys lists the keys
        in row order.
    """
    f = open_file(vectors_loc)
    shape = tuple(int(size) for size in next(f).split())
    vectors_data = numpy.zeros(shape=shape, dtype="f")
    vectors_keys = []
    width = vectors_data.shape[1]
    for i, line in enumerate(tqdm(f)):
        line = line.rstrip()
        # Split from the right exactly `width` times: the values never
        # contain spaces, so whatever remains on the left is the key, even
        # if the key itself contains spaces.
        pieces = line.rsplit(" ", width)
        word = pieces.pop(0)
        if len(pieces) != width:
            msg.fail(Errors.E094.format(line_num=i, loc=vectors_loc), exits=1)
        vectors_data[i] = numpy.asarray(pieces, dtype="f")
        vectors_keys.append(word)
    return vectors_data, vectors_keys
|
|
|
|
|
|
def read_freqs(freqs_loc, max_length=100, min_doc_freq=5, min_freq=50):
    """Compute smoothed log probabilities from a tab-separated counts file.

    Each line is expected to hold three tab-separated fields: frequency,
    document frequency and a Python literal for the word.

    freqs_loc (Path): Location of the frequencies file.
    max_length (int): Only keys shorter than this are kept.
    min_doc_freq (int): Minimum document frequency for a word to be kept.
    min_freq (int): Minimum frequency for a word to be kept.
    RETURNS (tuple): (probs, oov_prob), where probs maps each kept word to
        its log probability and oov_prob is the log probability of a zero
        count.
    """
    # First pass: register every count with the smoother and sum the total.
    counts = PreshCounter()
    total = 0
    with freqs_loc.open() as f:
        for line_index, line in enumerate(f):
            raw_freq, _doc_freq, _key = line.rstrip().split("\t", 2)
            count = int(raw_freq)
            counts.inc(line_index + 1, count)
            total += count
    counts.smooth()
    log_total = math.log(total)
    # Second pass: keep the words that clear all thresholds.
    probs = {}
    with freqs_loc.open() as f:
        for line in tqdm(f):
            raw_freq, raw_doc_freq, key = line.rstrip().split("\t", 2)
            doc_freq = int(raw_doc_freq)
            freq = int(raw_freq)
            if doc_freq < min_doc_freq or freq < min_freq:
                continue
            if len(key) >= max_length:
                continue
            word = literal_eval(key)
            probs[word] = math.log(counts.smoother(freq)) - log_total
    oov_prob = math.log(counts.smoother(0)) - log_total
    return probs, oov_prob
|
|
|
|
|
|
def read_clusters(clusters_loc):
    """Read a clusters file into a dict mapping words to cluster strings.

    Each usable line holds three whitespace-separated fields: cluster, word
    and frequency. Lines that do not split into exactly three fields are
    skipped.

    clusters_loc (Path): Location of the clusters file.
    RETURNS (dict): Word -> cluster string, including lower-, title- and
        upper-cased variants of every word that were not already present.
    """
    if ftfy is None:
        user_warning(Warnings.W004)
    clusters = {}
    with clusters_loc.open() as f:
        for line in tqdm(f):
            try:
                cluster, word, freq = line.split()
                if ftfy is not None:
                    word = ftfy.fix_text(word)
            except ValueError:
                continue
            # If the clusterer has only seen the word a few times, its
            # cluster is unreliable.
            clusters[word] = cluster if int(freq) >= 3 else "0"
    # Expand clusters with re-casing
    for word, cluster in list(clusters.items()):
        for variant in (word.lower(), word.title(), word.upper()):
            if variant not in clusters:
                clusters[variant] = cluster
    return clusters
|