2017-11-27 01:21:47 +03:00
|
|
|
# coding: utf8
|
|
|
|
from __future__ import unicode_literals
|
|
|
|
|
|
|
|
import plac
|
|
|
|
import math
|
|
|
|
from tqdm import tqdm
|
|
|
|
import numpy
|
|
|
|
from ast import literal_eval
|
|
|
|
from pathlib import Path
|
|
|
|
from preshed.counter import PreshCounter
|
2018-03-21 16:33:23 +03:00
|
|
|
import tarfile
|
|
|
|
import gzip
|
2018-03-28 00:01:18 +03:00
|
|
|
import zipfile
|
💫 Replace ujson, msgpack and dill/pickle/cloudpickle with srsly (#3003)
Remove hacks and wrappers, keep code in sync across our libraries and move spaCy a few steps closer to only depending on packages with binary wheels 🎉
See here: https://github.com/explosion/srsly
Serialization is hard, especially across Python versions and multiple platforms. After dealing with many subtle bugs over the years (encodings, locales, large files) our libraries like spaCy and Prodigy have steadily grown a number of utility functions to wrap the multiple serialization formats we need to support (especially json, msgpack and pickle). These wrapping functions ended up duplicated across our codebases, so we wanted to put them in one place.
At the same time, we noticed that having a lot of small dependencies was making maintenance harder, and making installation slower. To solve this, we've made srsly standalone, by including the component packages directly within it. This way we can provide all the serialization utilities we need in a single binary wheel.
srsly currently includes forks of the following packages:
ujson
msgpack
msgpack-numpy
cloudpickle
* WIP: replace json/ujson with srsly
* Replace ujson in examples
Use regular json instead of srsly to make code easier to read and follow
* Update requirements
* Fix imports
* Fix typos
* Replace msgpack with srsly
* Fix warning
2018-12-03 03:28:22 +03:00
|
|
|
import srsly
|
2018-11-30 22:16:14 +03:00
|
|
|
from wasabi import Printer
|
2017-11-27 01:21:47 +03:00
|
|
|
|
2017-12-07 12:03:07 +03:00
|
|
|
from ..vectors import Vectors
|
2018-04-10 22:26:37 +03:00
|
|
|
from ..errors import Errors, Warnings, user_warning
|
💫 Replace ujson, msgpack and dill/pickle/cloudpickle with srsly (#3003)
Remove hacks and wrappers, keep code in sync across our libraries and move spaCy a few steps closer to only depending on packages with binary wheels 🎉
See here: https://github.com/explosion/srsly
Serialization is hard, especially across Python versions and multiple platforms. After dealing with many subtle bugs over the years (encodings, locales, large files) our libraries like spaCy and Prodigy have steadily grown a number of utility functions to wrap the multiple serialization formats we need to support (especially json, msgpack and pickle). These wrapping functions ended up duplicated across our codebases, so we wanted to put them in one place.
At the same time, we noticed that having a lot of small dependencies was making maintenance harder, and making installation slower. To solve this, we've made srsly standalone, by including the component packages directly within it. This way we can provide all the serialization utilities we need in a single binary wheel.
srsly currently includes forks of the following packages:
ujson
msgpack
msgpack-numpy
cloudpickle
* WIP: replace json/ujson with srsly
* Replace ujson in examples
Use regular json instead of srsly to make code easier to read and follow
* Update requirements
* Fix imports
* Fix typos
* Replace msgpack with srsly
* Fix warning
2018-12-03 03:28:22 +03:00
|
|
|
from ..util import ensure_path, get_lang_class
|
2017-11-27 01:21:47 +03:00
|
|
|
|
2018-04-10 20:08:06 +03:00
|
|
|
try:
|
|
|
|
import ftfy
|
|
|
|
except ImportError:
|
|
|
|
ftfy = None
|
|
|
|
|
2017-11-27 01:21:47 +03:00
|
|
|
|
2019-08-01 18:26:09 +03:00
|
|
|
# Fallback log probability stored as the vocab's "oov_prob" when the vocab is
# empty, so no minimum probability can be derived from its lexemes.
DEFAULT_OOV_PROB = -20
|
2018-11-30 22:16:14 +03:00
|
|
|
# wasabi Printer used for the CLI status messages (loading/good/warn/fail)
# emitted by the functions in this module.
msg = Printer()
|
|
|
|
|
|
|
|
|
2017-11-27 01:21:47 +03:00
|
|
|
@plac.annotations(
    lang=("Model language", "positional", None, str),
    output_dir=("Model output directory", "positional", None, Path),
    freqs_loc=("Location of words frequencies file", "option", "f", Path),
    jsonl_loc=("Location of JSONL-formatted attributes file", "option", "j", Path),
    clusters_loc=("Optional location of brown clusters data", "option", "c", str),
    vectors_loc=("Optional vectors file in Word2Vec format", "option", "v", str),
    prune_vectors=("Optional number of vectors to prune to", "option", "V", int),
)
def init_model(
    lang,
    output_dir,
    freqs_loc=None,
    clusters_loc=None,
    jsonl_loc=None,
    vectors_loc=None,
    prune_vectors=-1,
):
    """
    Create a new model from raw data, like word frequencies, Brown clusters
    and word vectors. If vectors are provided in Word2Vec format, they can
    be either a .txt or zipped as a .zip or .tar.gz.

    lang (unicode): Language code, passed on to create_model.
    output_dir (unicode / Path): Directory to save the model to. It is
        created, including any missing parent directories, if needed.
    freqs_loc (Path): Deprecated word frequencies file (-f).
    clusters_loc (unicode / Path): Deprecated Brown clusters file (-c).
    jsonl_loc (Path): JSONL file of lexical attributes (-j). When given,
        freqs_loc and clusters_loc are ignored (with a warning).
    vectors_loc (unicode / Path): Optional vectors file, see add_vectors.
    prune_vectors (int): If >= 1, prune the vectors to this many rows.
    RETURNS (Language): The created nlp object, also saved to output_dir.
    """
    output_dir = ensure_path(output_dir)
    if jsonl_loc is not None:
        if freqs_loc is not None or clusters_loc is not None:
            msg.warn(
                "Incompatible arguments",
                "The -f and -c arguments are deprecated, and not compatible "
                "with the -j argument, which should specify the same "
                "information. Either merge the frequencies and clusters data "
                "into the JSONL-formatted file (recommended), or use only the "
                "-f and -c files, without the other lexical attributes.",
            )
        jsonl_loc = ensure_path(jsonl_loc)
        # Fail early with a clear message, like the frequencies check below.
        if not jsonl_loc.exists():
            msg.fail("Can't find JSONL attributes file", jsonl_loc, exits=1)
        lex_attrs = srsly.read_jsonl(jsonl_loc)
    else:
        clusters_loc = ensure_path(clusters_loc)
        freqs_loc = ensure_path(freqs_loc)
        if freqs_loc is not None and not freqs_loc.exists():
            msg.fail("Can't find words frequencies file", freqs_loc, exits=1)
        if clusters_loc is not None and not clusters_loc.exists():
            msg.fail("Can't find Brown clusters file", clusters_loc, exits=1)
        lex_attrs = read_attrs_from_deprecated(freqs_loc, clusters_loc)

    with msg.loading("Creating model..."):
        nlp = create_model(lang, lex_attrs)
    msg.good("Successfully created model")
    if vectors_loc is not None:
        add_vectors(nlp, vectors_loc, prune_vectors)
    vec_added = len(nlp.vocab.vectors)
    lex_added = len(nlp.vocab)
    msg.good(
        "Successfully compiled vocab",
        "{} entries, {} vectors".format(lex_added, vec_added),
    )
    if not output_dir.exists():
        # parents=True so a nested, not-yet-existing path can be used.
        output_dir.mkdir(parents=True)
    nlp.to_disk(output_dir)
    return nlp
|
|
|
|
|
2018-11-30 22:16:14 +03:00
|
|
|
|
2018-03-21 16:33:23 +03:00
|
|
|
def open_file(loc):
    """Handle .gz, .tar.gz or unzipped files.

    loc (unicode / Path): Location of the file to read.
    RETURNS: An iterable of unicode lines. For tar archives (any compression)
        and .zip archives, the lines are read from the first regular file
        inside the archive. For gzip and archive inputs the underlying
        handles are closed once the lines are exhausted or the returned
        generator is closed. Plain files are returned as an open text file.
    RAISES (ValueError): If an archive contains no regular file.
    """
    loc = ensure_path(loc)
    if tarfile.is_tarfile(str(loc)):
        # Iterating a TarFile yields TarInfo members, not text lines, so read
        # the first regular file member instead. "r:*" also accepts plain,
        # bz2- or xz-compressed tarballs, not just gzip.
        tar_file = tarfile.open(str(loc), "r:*")
        for member in tar_file:
            if member.isfile():
                return _iter_utf8_lines(tar_file.extractfile(member), tar_file)
        tar_file.close()
        raise ValueError("No regular file found in tar archive: {}".format(loc))
    elif loc.parts[-1].endswith("gz"):
        return _iter_utf8_lines(gzip.open(str(loc), "r"))
    elif loc.parts[-1].endswith("zip"):
        zip_file = zipfile.ZipFile(str(loc))
        # Zip directory entries end with "/"; they have no content to read.
        names = [name for name in zip_file.namelist() if not name.endswith("/")]
        if not names:
            zip_file.close()
            raise ValueError("No file found in zip archive: {}".format(loc))
        return _iter_utf8_lines(zip_file.open(names[0]), zip_file)
    else:
        return loc.open("r", encoding="utf8")


def _iter_utf8_lines(binary_file, container=None):
    """Yield the lines of a binary file object decoded as UTF-8, closing the
    file (and the archive it was opened from, if given) when done."""
    try:
        for line in binary_file:
            yield line.decode("utf8")
    finally:
        binary_file.close()
        if container is not None:
            container.close()
|
|
|
|
|
2018-03-21 16:33:23 +03:00
|
|
|
|
2018-07-03 13:22:56 +03:00
|
|
|
def read_attrs_from_deprecated(freqs_loc, clusters_loc):
    """Build lexical attribute dicts from the deprecated -f / -c inputs.

    freqs_loc (Path / None): Word frequencies file, read with read_freqs.
    clusters_loc (Path / None): Brown clusters file, read with read_clusters.
    RETURNS (list): One dict per word that has a probability, with the keys
        "orth", "id", "prob" and "cluster". Words are ordered by descending
        probability and "id" is their position in that order. Words that
        only appear in the clusters file are not included.
    """
    if freqs_loc is not None:
        with msg.loading("Counting frequencies..."):
            probs, _ = read_freqs(freqs_loc)
        msg.good("Counted frequencies")
    else:
        probs = {}
    if clusters_loc:
        with msg.loading("Reading clusters..."):
            clusters = read_clusters(clusters_loc)
        msg.good("Read clusters")
    else:
        clusters = {}
    if clusters and not probs:
        # Clusters are only attached to words that have a probability, so
        # without frequencies the clusters data would be dropped silently.
        msg.warn(
            "No word frequencies available",
            "Brown clusters are only assigned to words with a frequency "
            "entry (-f), so the clusters data will be ignored.",
        )
    lex_attrs = []
    sorted_probs = sorted(probs.items(), key=lambda item: item[1], reverse=True)
    # Skip the progress bar entirely when there is nothing to process.
    if sorted_probs:
        # tqdm wraps the list (not the enumerate object) so it knows the total.
        for i, (word, prob) in enumerate(tqdm(sorted_probs)):
            attrs = {"orth": word, "id": i, "prob": prob}
            # Decode as a little-endian string, so that we can do & 15 to get
            # the first 4 bits. See _parse_features.pyx
            if word in clusters:
                attrs["cluster"] = int(clusters[word][::-1], 2)
            else:
                attrs["cluster"] = 0
            lex_attrs.append(attrs)
    return lex_attrs
|
|
|
|
|
|
|
|
|
2018-07-04 03:29:48 +03:00
|
|
|
def create_model(lang, lex_attrs):
    """Create a blank nlp object for a language and fill its vocab.

    lang (unicode): Language code, resolved with get_lang_class.
    lex_attrs (iterable): Dicts of lexeme attributes, each with an "orth"
        key; entries containing a "settings" key are skipped.
    RETURNS (Language): The nlp object. Its vocab.cfg["oov_prob"] is set to
        one below the lowest lexeme probability, or to DEFAULT_OOV_PROB if
        the vocab is empty.
    """
    lang_class = get_lang_class(lang)
    nlp = lang_class()
    # Reset rank to 0 for every lexeme already present in the fresh vocab.
    for lexeme in nlp.vocab:
        lexeme.rank = 0
    for attrs in lex_attrs:
        if "settings" in attrs:
            continue
        lexeme = nlp.vocab[attrs["orth"]]
        lexeme.set_attrs(**attrs)
        lexeme.is_oov = False
    if len(nlp.vocab):
        oov_prob = min(lex.prob for lex in nlp.vocab) - 1
    else:
        oov_prob = DEFAULT_OOV_PROB
    nlp.vocab.cfg.update({"oov_prob": oov_prob})
    return nlp
|
|
|
|
|
2018-11-30 22:16:14 +03:00
|
|
|
|
2018-07-04 03:29:48 +03:00
|
|
|
def add_vectors(nlp, vectors_loc, prune_vectors):
    """Load vectors into nlp.vocab and name the vectors table.

    nlp (Language): The nlp object whose vocab receives the vectors.
    vectors_loc (unicode / Path / None): A .npz file loaded with numpy.load,
        or a Word2Vec-style text file read with read_vectors. If None, no
        vectors are loaded, but the existing table is still named (and
        pruned).
    prune_vectors (int / None): If >= 1, prune the vectors table to this
        many rows. None or values below 1 disable pruning.
    """
    vectors_loc = ensure_path(vectors_loc)
    if vectors_loc and vectors_loc.parts[-1].endswith(".npz"):
        # Pass the path rather than an open handle so the handle is not
        # leaked. NOTE(review): for a genuine .npz archive numpy.load returns
        # a lazy NpzFile mapping rather than an array -- confirm Vectors
        # accepts what is loaded here.
        nlp.vocab.vectors = Vectors(data=numpy.load(str(vectors_loc)))
        # Rank 0 is treated as "no row" (create_model resets existing
        # lexemes to rank 0). NOTE(review): confirm a genuine row 0 isn't
        # meant to be added as well.
        for lex in nlp.vocab:
            if lex.rank:
                nlp.vocab.vectors.add(lex.orth, row=lex.rank)
    else:
        if vectors_loc:
            with msg.loading("Reading vectors from {}".format(vectors_loc)):
                vectors_data, vector_keys = read_vectors(vectors_loc)
            msg.good("Loaded vectors from {}".format(vectors_loc))
        else:
            vectors_data, vector_keys = (None, None)
        if vector_keys is not None:
            # Make sure every word that has a vector gets a lexeme.
            for word in vector_keys:
                if word not in nlp.vocab:
                    lexeme = nlp.vocab[word]
                    lexeme.is_oov = False
        if vectors_data is not None:
            nlp.vocab.vectors = Vectors(data=vectors_data, keys=vector_keys)
    nlp.vocab.vectors.name = "%s_model.vectors" % nlp.meta["lang"]
    nlp.meta["vectors"]["name"] = nlp.vocab.vectors.name
    # Guard against None from programmatic callers (None >= 1 is a TypeError).
    if prune_vectors is not None and prune_vectors >= 1:
        nlp.vocab.prune_vectors(prune_vectors)
|
2017-11-27 01:21:47 +03:00
|
|
|
|
2018-11-30 22:16:14 +03:00
|
|
|
|
2017-11-27 01:21:47 +03:00
|
|
|
def read_vectors(vectors_loc):
    """Read vectors from a Word2Vec-style text file.

    The first line is a header with the number of rows and the vector width.
    Each following line holds a word and its values separated by single
    spaces. The values are split off from the right, so the word itself may
    contain spaces.

    vectors_loc (unicode / Path): Location of the file. Compressed and
        archived files are handled by open_file.
    RETURNS (tuple): (vectors_data, vectors_keys), a float32 array shaped as
        declared in the header, and the list of words in file order. Rows
        beyond the number of words read are left as zeros.
    """
    f = open_file(vectors_loc)
    try:
        shape = tuple(int(size) for size in next(f).split())
        vectors_data = numpy.zeros(shape=shape, dtype="f")
        vectors_keys = []
        for i, line in enumerate(tqdm(f)):
            # The header is line 1 of the file, so data line i is line i + 2.
            line_num = i + 2
            if i >= vectors_data.shape[0]:
                # More data lines than the header declares: report it instead
                # of failing with an IndexError on the assignment below.
                msg.fail(Errors.E094.format(line_num=line_num, loc=vectors_loc), exits=1)
            line = line.rstrip()
            pieces = line.rsplit(" ", vectors_data.shape[1])
            word = pieces.pop(0)
            if len(pieces) != vectors_data.shape[1]:
                msg.fail(Errors.E094.format(line_num=line_num, loc=vectors_loc), exits=1)
            vectors_data[i] = numpy.asarray(pieces, dtype="f")
            vectors_keys.append(word)
    finally:
        # Every object open_file returns (text file, generator or archive)
        # has a close() method; closing releases the underlying handles.
        f.close()
    return vectors_data, vectors_keys
|
|
|
|
|
|
|
|
|
|
|
|
def read_freqs(freqs_loc, max_length=100, min_doc_freq=5, min_freq=50):
    """Compute smoothed log probabilities from a word frequencies file.

    Each line has three tab-separated fields: frequency, document frequency
    and the word key (parsed with literal_eval, so presumably written as a
    Python string literal). The file is read twice: once to fit the smoother
    on all counts, and once to compute the probabilities of kept words.

    freqs_loc (Path): Location of the frequencies file, decoded as UTF-8.
    max_length (int): Keep only words whose key is shorter than this.
    min_doc_freq (int): Minimum document frequency for a word to be kept.
    min_freq (int): Minimum frequency for a word to be kept.
    RETURNS (tuple): (probs, oov_prob), a dict mapping words to smoothed log
        probabilities, and the smoothed log probability of an unseen word.
    """
    counts = PreshCounter()
    total = 0
    with freqs_loc.open(encoding="utf8") as f:
        for i, line in enumerate(f):
            freq, doc_freq, key = line.rstrip().split("\t", 2)
            freq = int(freq)
            counts.inc(i + 1, freq)
            total += freq
    counts.smooth()
    log_total = math.log(total)
    probs = {}
    with freqs_loc.open(encoding="utf8") as f:
        for line in tqdm(f):
            freq, doc_freq, key = line.rstrip().split("\t", 2)
            doc_freq = int(doc_freq)
            freq = int(freq)
            if doc_freq >= min_doc_freq and freq >= min_freq and len(key) < max_length:
                word = _parse_freq_key(key)
                smooth_count = counts.smoother(freq)
                probs[word] = math.log(smooth_count) - log_total
    oov_prob = math.log(counts.smoother(0)) - log_total
    return probs, oov_prob


def _parse_freq_key(key):
    """Turn the word field of a frequencies line into a string.

    String literals (e.g. "'word'") are evaluated. Anything else -- bare
    words, numbers, None/True/False, malformed literals -- is taken
    literally instead of crashing or becoming a non-string "word".
    """
    # Text and byte string types on both Python 2 and 3.
    string_types = (type(""), type(b""))
    try:
        word = literal_eval(key)
    except (SyntaxError, ValueError, TypeError):
        # literal_eval raises ValueError (not SyntaxError) for bare names.
        pass
    else:
        if isinstance(word, string_types):
            return word
    try:
        # Take odd strings literally.
        return literal_eval("'%s'" % key)
    except (SyntaxError, ValueError, TypeError):
        return key
|
|
|
|
|
|
|
|
|
|
|
|
def read_clusters(clusters_loc):
    """Read Brown clusters into a dict mapping words to cluster bit strings.

    Each line has three whitespace-separated fields: cluster, word and
    frequency. Malformed lines are skipped. Words with a frequency below 3
    get the cluster "0". Lower-, title- and upper-cased variants of each
    word inherit its cluster unless they have an entry of their own. Words
    are passed through ftfy.fix_text when ftfy is installed.

    clusters_loc (Path): Location of the clusters file, decoded as UTF-8.
    RETURNS (dict): Words mapped to cluster strings.
    """
    clusters = {}
    if ftfy is None:
        user_warning(Warnings.W004)
    with clusters_loc.open(encoding="utf8") as f:
        for line in tqdm(f):
            try:
                cluster, word, freq = line.split()
                if ftfy is not None:
                    word = ftfy.fix_text(word)
                # Parse inside the try so a bad frequency skips the line
                # like any other malformed line.
                freq = int(freq)
            except ValueError:
                continue
            # If the clusterer has only seen the word a few times, its
            # cluster is unreliable.
            if freq >= 3:
                clusters[word] = cluster
            else:
                clusters[word] = "0"
    # Expand clusters with re-casing
    for word, cluster in list(clusters.items()):
        for variant in (word.lower(), word.title(), word.upper()):
            if variant not in clusters:
                clusters[variant] = cluster
    return clusters
|