Mirror of https://github.com/explosion/spaCy.git (synced 2025-01-13 02:36:32 +03:00)

Commit 14b89ff1c5: Merge remote-tracking branch 'refs/remotes/honnibal/master'
@@ -11,6 +11,7 @@ import ujson
 import codecs
 from preshed.counter import PreshCounter
 from joblib import Parallel, delayed
+import io

 from spacy.en import English
 from spacy.strings import StringStore
examples/nn_text_class.py (new file, 273 lines)
@@ -0,0 +1,273 @@
from __future__ import unicode_literals
from __future__ import print_function
from __future__ import division

from collections import defaultdict
from pathlib import Path
import numpy
import plac

import spacy.en


def read_data(nlp, data_dir):
    for subdir, label in (('pos', 1), ('neg', 0)):
        for filename in (data_dir / subdir).iterdir():
            text = filename.open().read()
            doc = nlp(text)
            if len(doc) >= 1:
                yield doc, label


def partition(examples, split_size):
    examples = list(examples)
    numpy.random.shuffle(examples)
    n_docs = len(examples)
    split = int(n_docs * split_size)
    return examples[:split], examples[split:]


def minibatch(data, bs=24):
    for i in range(0, len(data), bs):
        yield data[i:i+bs]


class Extractor(object):
    def __init__(self, nlp, vector_length, dropout=0.3):
        self.nlp = nlp
        self.dropout = dropout
        self.vector = numpy.zeros((vector_length, ))

    def doc2bow(self, doc, dropout=None):
        if dropout is None:
            dropout = self.dropout
        bow = defaultdict(int)
        all_words = defaultdict(int)
        for word in doc:
            if numpy.random.random() >= dropout and not word.is_punct:
                bow[word.lower] += 1
            all_words[word.lower] += 1
        if sum(bow.values()) >= 1:
            return bow
        else:
            return all_words

    def bow2vec(self, bow, E):
        self.vector.fill(0)
        n = 0
        for orth_id, freq in bow.items():
            self.vector += self.nlp.vocab[self.nlp.vocab.strings[orth_id]].repvec * freq
            # Apply the fine-tuning we've learned
            if orth_id < E.shape[0]:
                self.vector += E[orth_id] * freq
            n += freq
        return self.vector / n


class NeuralNetwork(object):
    def __init__(self, depth, width, n_classes, n_vocab, extracter, optimizer):
        self.depth = depth
        self.width = width
        self.n_classes = n_classes
        self.weights = Params.random(depth, width, width, n_classes, n_vocab)
        self.doc2bow = extracter.doc2bow
        self.bow2vec = extracter.bow2vec
        self.optimizer = optimizer
        self._gradient = Params.zero(depth, width, width, n_classes, n_vocab)
        self._activity = numpy.zeros((depth, width))

    def train(self, batch):
        activity = self._activity
        gradient = self._gradient
        activity.fill(0)
        gradient.data.fill(0)
        loss = 0
        word_freqs = defaultdict(int)
        for doc, label in batch:
            word_ids = self.doc2bow(doc)
            vector = self.bow2vec(word_ids, self.weights.E)
            self.forward(activity, vector)
            loss += self.backprop(vector, gradient, activity, word_ids, label)
            for w, freq in word_ids.items():
                word_freqs[w] += freq
        self.optimizer(self.weights, gradient, len(batch), word_freqs)
        return loss

    def predict(self, doc):
        actv = self._activity
        actv.fill(0)
        W = self.weights.W
        b = self.weights.b
        E = self.weights.E

        vector = self.bow2vec(self.doc2bow(doc, dropout=0.0), E)
        self.forward(actv, vector)
        return numpy.argmax(softmax(actv[-1], W[-1], b[-1]))

    def forward(self, actv, in_):
        actv.fill(0)
        W = self.weights.W; b = self.weights.b
        actv[0] = relu(in_, W[0], b[0])
        for i in range(1, self.depth):
            actv[i] = relu(actv[i-1], W[i], b[i])

    def backprop(self, input_vector, gradient, activity, ids, label):
        W = self.weights.W
        b = self.weights.b

        target = numpy.zeros(self.n_classes)
        target[label] = 1.0
        pred = softmax(activity[-1], W[-1], b[-1])
        delta = pred - target

        for i in range(self.depth, 0, -1):
            gradient.b[i] += delta
            gradient.W[i] += numpy.outer(delta, activity[i-1])
            delta = d_relu(activity[i-1]) * W[i].T.dot(delta)

        gradient.b[0] += delta
        gradient.W[0] += numpy.outer(delta, input_vector)
        tuning = W[0].T.dot(delta).reshape((self.width,)) / len(ids)
        for w, freq in ids.items():
            if w < gradient.E.shape[0]:
                gradient.E[w] += tuning * freq
        return -sum(target * numpy.log(pred))


def softmax(actvn, W, b):
    w = W.dot(actvn) + b
    ew = numpy.exp(w - max(w))
    return (ew / sum(ew)).ravel()


def relu(actvn, W, b):
    x = W.dot(actvn) + b
    return x * (x > 0)


def d_relu(x):
    return x > 0


class Adagrad(object):
    def __init__(self, lr, rho):
        self.eps = 1e-3
        # initial learning rate
        self.learning_rate = lr
        self.rho = rho
        # stores sum of squared gradients
        #self.h = numpy.zeros(self.dim)
        #self._curr_rate = numpy.zeros(self.h.shape)
        self.h = None
        self._curr_rate = None

    def __call__(self, weights, gradient, batch_size, word_freqs):
        if self.h is None:
            self.h = numpy.zeros(gradient.data.shape)
            self._curr_rate = numpy.zeros(gradient.data.shape)
        self.L2_penalty(gradient, weights, word_freqs)
        update = self.rescale(gradient.data / batch_size)
        weights.data -= update

    def rescale(self, gradient):
        if self.h is None:
            self.h = numpy.zeros(gradient.data.shape)
            self._curr_rate = numpy.zeros(gradient.data.shape)
        self._curr_rate.fill(0)
        self.h += gradient ** 2
        self._curr_rate = self.learning_rate / (numpy.sqrt(self.h) + self.eps)
        return self._curr_rate * gradient

    def L2_penalty(self, gradient, weights, word_freqs):
        # L2 Regularization
        for i in range(len(weights.W)):
            gradient.W[i] += weights.W[i] * self.rho
            gradient.b[i] += weights.b[i] * self.rho
        for w, freq in word_freqs.items():
            if w < gradient.E.shape[0]:
                gradient.E[w] += weights.E[w] * self.rho


class Params(object):
    @classmethod
    def zero(cls, depth, n_embed, n_hidden, n_labels, n_vocab):
        return cls(depth, n_embed, n_hidden, n_labels, n_vocab, lambda x: numpy.zeros((x,)))

    @classmethod
    def random(cls, depth, nE, nH, nL, nV):
        return cls(depth, nE, nH, nL, nV, lambda x: (numpy.random.rand(x) * 2 - 1) * 0.08)

    def __init__(self, depth, n_embed, n_hidden, n_labels, n_vocab, initializer):
        nE = n_embed; nH = n_hidden; nL = n_labels; nV = n_vocab
        n_weights = sum([
            (nE * nH) + nH,
            (nH * nH + nH) * depth,
            (nH * nL) + nL,
            (nV * nE)
        ])
        self.data = initializer(n_weights)
        self.W = []
        self.b = []
        i = self._add_layer(0, nE, nH)
        for _ in range(1, depth):
            i = self._add_layer(i, nH, nH)
        i = self._add_layer(i, nL, nH)
        self.E = self.data[i : i + (nV * nE)].reshape((nV, nE))
        self.E.fill(0)

    def _add_layer(self, start, x, y):
        end = start + (x * y)
        self.W.append(self.data[start : end].reshape((x, y)))
        self.b.append(self.data[end : end + x].reshape((x, )))
        return end + x


@plac.annotations(
    data_dir=("Data directory", "positional", None, Path),
    n_iter=("Number of iterations (epochs)", "option", "i", int),
    width=("Size of hidden layers", "option", "H", int),
    depth=("Depth", "option", "d", int),
    dropout=("Drop-out rate", "option", "r", float),
    rho=("Regularization penalty", "option", "p", float),
    eta=("Learning rate", "option", "e", float),
    batch_size=("Batch size", "option", "b", int),
    vocab_size=("Number of words to fine-tune", "option", "w", int),
)
def main(data_dir, depth=3, width=300, n_iter=5, vocab_size=40000,
         batch_size=24, dropout=0.3, rho=1e-5, eta=0.005):
    n_classes = 2
    print("Loading")
    nlp = spacy.en.English(parser=False)
    train_data, dev_data = partition(read_data(nlp, data_dir / 'train'), 0.8)
    print("Begin training")
    extracter = Extractor(nlp, width, dropout=0.3)
    optimizer = Adagrad(eta, rho)
    model = NeuralNetwork(depth, width, n_classes, vocab_size, extracter, optimizer)
    prev_best = 0
    best_weights = None
    for epoch in range(n_iter):
        numpy.random.shuffle(train_data)
        train_loss = 0.0
        for batch in minibatch(train_data, bs=batch_size):
            train_loss += model.train(batch)
        n_correct = sum(model.predict(x) == y for x, y in dev_data)
        print(epoch, train_loss, n_correct / len(dev_data))
        if n_correct >= prev_best:
            best_weights = model.weights.data.copy()
            prev_best = n_correct

    model.weights.data = best_weights
    print("Evaluating")
    eval_data = list(read_data(nlp, data_dir / 'test'))
    n_correct = sum(model.predict(x) == y for x, y in eval_data)
    print(n_correct / len(eval_data))


if __name__ == '__main__':
    #import cProfile
    #import pstats
    #cProfile.runctx("main(Path('data/aclImdb'))", globals(), locals(), "Profile.prof")
    #s = pstats.Stats("Profile.prof")
    #s.strip_dirs().sort_stats("time").print_stats(100)
    plac.call(main)
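The script above is a plac command line tool; it can also be called programmatically. A minimal sketch, not part of the commit, assuming the IMDB review data is unpacked at data/aclImdb/ with train/{pos,neg} and test/{pos,neg} subdirectories (the layout read_data() expects) and that this is run from the examples/ directory:

from pathlib import Path

import nn_text_class  # hypothetical import of the script above

# Equivalent to: python examples/nn_text_class.py data/aclImdb -d 3 -H 300 -i 5
nn_text_class.main(Path('data/aclImdb'), depth=3, width=300, n_iter=5)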
fabfile.py (vendored, 2 changed lines)
@@ -48,7 +48,7 @@ def prebuild(build_dir='/tmp/build_spacy'):
     local('virtualenv ' + build_venv)
     with prefix('cd %s && PYTHONPATH=`pwd` && . %s/bin/activate' % (build_dir, build_venv)):
         local('pip install cython fabric fabtools pytest')
-        local('pip install -r requirements.txt')
+        local('pip install --no-cache-dir -r requirements.txt')
         local('fab clean make')
         local('cp -r %s/corpora/en/wordnet corpora/en/' % spacy_dir)
         local('cp %s/corpora/en/freqs.txt.gz corpora/en/' % spacy_dir)
@@ -342,7 +342,7 @@ hardcoded_specials = {
     "\n": [{"F": "\n", "pos": "SP"}],
     "\t": [{"F": "\t", "pos": "SP"}],
     " ": [{"F": " ", "pos": "SP"}],
-    u"\xa0": [{"F": u"\xa0", "pos": "SP", "L": " "}]
+    u"\u00a0": [{"F": u"\u00a0", "pos": "SP", "L": " "}]

 }
@@ -1,3 +1,4 @@
 \.\.\.
 (?<=[a-z])\.(?=[A-Z])
 (?<=[a-zA-Z])-(?=[a-zA-z])
+(?<=[0-9])-(?=[0-9])
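For illustration only (not part of the diff): the new infix pattern matches a hyphen flanked by digits, so a numeric range splits into three pieces, which is what the test_numeric_range test added later in this commit expects.

import re

# Capture the hyphen so the separator survives the split, mirroring how the
# tokenizer keeps the infix as its own token.
infix = re.compile(r"((?<=[0-9])-(?=[0-9]))")
print(infix.split("0.1-13.5"))   # ['0.1', '-', '13.5'] -- three tokens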
@@ -6,7 +6,6 @@ thinc == 3.3
 murmurhash == 0.24
 text-unidecode
 numpy
-wget
 plac
 six
 ujson
setup.py (9 changed lines)
@@ -162,7 +162,7 @@ def run_setup(exts):
         ext_modules=exts,
         license="MIT",
         install_requires=['numpy', 'murmurhash', 'cymem >= 1.11', 'preshed >= 0.42',
-                          'thinc == 3.3', "text_unidecode", 'wget', 'plac', 'six',
+                          'thinc == 3.3', "text_unidecode", 'plac', 'six',
                           'ujson', 'cloudpickle'],
         setup_requires=["headers_workaround"],
         cmdclass = {'build_ext': build_ext_subclass },
@@ -175,13 +175,14 @@ def run_setup(exts):
     headers_workaround.install_headers('numpy')


-VERSION = '0.95'
+VERSION = '0.96'
 def main(modules, is_pypy):
     language = "cpp"
     includes = ['.', path.join(sys.prefix, 'include')]
     if sys.platform.startswith('darwin'):
-        compile_options['other'].append(['-mmacosx-version-min=10.8', '-stdlib=libc++'])
-        link_opions['other'].append('-lc++')
+        compile_options['other'].append('-mmacosx-version-min=10.8')
+        compile_options['other'].append('-stdlib=libc++')
+        link_options['other'].append('-lc++')
     if use_cython:
         cython_setup(modules, language, includes)
     else:
@@ -1,11 +1,13 @@
 from __future__ import print_function
 from os import path
+import sys
 import os
 import tarfile
 import shutil
-import wget
 import plac

+from . import uget
+
 # TODO: Read this from the same source as the setup
 VERSION = '0.9.5'

@@ -13,39 +15,45 @@ AWS_STORE = 'https://s3-us-west-1.amazonaws.com/media.spacynlp.com'

 ALL_DATA_DIR_URL = '%s/en_data_all-%s.tgz' % (AWS_STORE, VERSION)

-DEST_DIR = path.join(path.dirname(__file__), 'data')
+DEST_DIR = path.join(path.dirname(path.abspath(__file__)), 'data')

-def download_file(url, out):
-    wget.download(url, out=out)
-    return url.rsplit('/', 1)[1]
+def download_file(url, dest_dir):
+    return uget.download(url, dest_dir, console=sys.stdout)


 def install_data(url, dest_dir):
     filename = download_file(url, dest_dir)
-    t = tarfile.open(path.join(dest_dir, filename))
+    t = tarfile.open(filename)
     t.extractall(dest_dir)


 def install_parser_model(url, dest_dir):
     filename = download_file(url, dest_dir)
-    t = tarfile.open(path.join(dest_dir, filename), mode=":gz")
-    t.extractall(path.dirname(__file__))
+    t = tarfile.open(filename, mode=":gz")
+    t.extractall(dest_dir)


 def install_dep_vectors(url, dest_dir):
-    if not os.path.exists(dest_dir):
-        os.mkdir(dest_dir)
-
-    filename = download_file(url, dest_dir)
+    download_file(url, dest_dir)


-def main(data_size='all'):
+@plac.annotations(
+    force=("Force overwrite", "flag", "f", bool),
+)
+def main(data_size='all', force=False):
     if data_size == 'all':
         data_url = ALL_DATA_DIR_URL
     elif data_size == 'small':
         data_url = SM_DATA_DIR_URL
-    if path.exists(DEST_DIR):
+
+    if force and path.exists(DEST_DIR):
         shutil.rmtree(DEST_DIR)
-    install_data(data_url, path.dirname(DEST_DIR))
+
+    if not os.path.exists(DEST_DIR):
+        os.makedirs(DEST_DIR)
+
+    install_data(data_url, DEST_DIR)


 if __name__ == '__main__':
spacy/en/uget.py (new file, 246 lines)
@@ -0,0 +1,246 @@
import os
import time
import io
import math
import re

try:
    from urllib.parse import urlparse
    from urllib.request import urlopen, Request
    from urllib.error import HTTPError
except ImportError:
    from urllib2 import urlopen, urlparse, Request, HTTPError


class UnknownContentLengthException(Exception): pass
class InvalidChecksumException(Exception): pass
class UnsupportedHTTPCodeException(Exception): pass
class InvalidOffsetException(Exception): pass
class MissingChecksumHeader(Exception): pass


CHUNK_SIZE = 16 * 1024


class RateSampler(object):
    def __init__(self, period=1):
        self.rate = None
        self.reset = True
        self.period = period

    def __enter__(self):
        if self.reset:
            self.reset = False
            self.start = time.time()
            self.counter = 0

    def __exit__(self, type, value, traceback):
        elapsed = time.time() - self.start
        if elapsed >= self.period:
            self.reset = True
            self.rate = float(self.counter) / elapsed

    def update(self, value):
        self.counter += value

    def format(self, unit="MB"):
        if self.rate is None:
            return None

        divisor = {'MB': 1048576, 'kB': 1024}
        return "%0.2f%s/s" % (self.rate / divisor[unit], unit)


class TimeEstimator(object):
    def __init__(self, cooldown=1):
        self.cooldown = cooldown
        self.start = time.time()
        self.time_left = None

    def update(self, bytes_read, total_size):
        elapsed = time.time() - self.start
        if elapsed > self.cooldown:
            self.time_left = math.ceil(elapsed * total_size /
                                       bytes_read - elapsed)

    def format(self):
        if self.time_left is None:
            return None

        res = "eta "
        if self.time_left / 60 >= 1:
            res += "%dm " % (self.time_left / 60)
        return res + "%ds" % (self.time_left % 60)


def format_bytes_read(bytes_read, unit="MB"):
    divisor = {'MB': 1048576, 'kB': 1024}
    return "%0.2f%s" % (float(bytes_read) / divisor[unit], unit)


def format_percent(bytes_read, total_size):
    percent = round(bytes_read * 100.0 / total_size, 2)
    return "%0.2f%%" % percent


def get_content_range(response):
    content_range = response.headers.get('Content-Range', "").strip()
    if content_range:
        m = re.match(r"bytes (\d+)-(\d+)/(\d+)", content_range)
        if m:
            return [int(v) for v in m.groups()]


def get_content_length(response):
    if 'Content-Length' not in response.headers:
        raise UnknownContentLengthException
    return int(response.headers.get('Content-Length').strip())


def get_url_meta(url, checksum_header=None):
    class HeadRequest(Request):
        def get_method(self):
            return "HEAD"

    r = urlopen(HeadRequest(url))
    res = {'size': get_content_length(r)}

    if checksum_header:
        value = r.headers.get(checksum_header)
        if value:
            res['checksum'] = value

    r.close()
    return res


def progress(console, bytes_read, total_size, transfer_rate, eta):
    fields = [
        format_bytes_read(bytes_read),
        format_percent(bytes_read, total_size),
        transfer_rate.format(),
        eta.format(),
        " " * 10,
    ]
    console.write("Downloaded %s\r" % " ".join(filter(None, fields)))
    console.flush()


def read_request(request, offset=0, console=None,
                 progress_func=None, write_func=None):
    # support partial downloads
    if offset > 0:
        request.add_header('Range', "bytes=%s-" % offset)

    try:
        response = urlopen(request)
    except HTTPError as e:
        if e.code == 416:  # Requested Range Not Satisfiable
            raise InvalidOffsetException

        # TODO add http error handling here
        raise UnsupportedHTTPCodeException(e.code)

    total_size = get_content_length(response) + offset
    bytes_read = offset

    # sanity checks
    if response.code == 200:  # OK
        assert offset == 0
    elif response.code == 206:  # Partial content
        range_start, range_end, range_total = get_content_range(response)
        assert range_start == offset
        assert range_total == total_size
        assert range_end + 1 - range_start == total_size - bytes_read
    else:
        raise UnsupportedHTTPCodeException(response.code)

    eta = TimeEstimator()
    transfer_rate = RateSampler()

    if console:
        if offset > 0:
            console.write("Continue downloading...\n")
        else:
            console.write("Downloading...\n")

    while True:
        with transfer_rate:
            chunk = response.read(CHUNK_SIZE)
            if not chunk:
                if progress_func and console:
                    console.write('\n')
                break

            bytes_read += len(chunk)

            transfer_rate.update(len(chunk))
            eta.update(bytes_read - offset, total_size - offset)

            if progress_func and console:
                progress_func(console, bytes_read, total_size, transfer_rate, eta)

            if write_func:
                write_func(chunk)

    response.close()
    assert bytes_read == total_size
    return response


def download(url, path=".",
             checksum=None, checksum_header=None,
             headers=None, console=None):

    if os.path.isdir(path):
        path = os.path.join(path, url.rsplit('/', 1)[1])
    path = os.path.abspath(path)

    with io.open(path, "a+b") as f:
        size = f.tell()

        # update checksum of partially downloaded file
        if checksum:
            f.seek(0, os.SEEK_SET)
            for chunk in iter(lambda: f.read(CHUNK_SIZE), b""):
                checksum.update(chunk)

        def write(chunk):
            if checksum:
                checksum.update(chunk)
            f.write(chunk)

        request = Request(url)

        # request headers
        if headers:
            for key, value in headers.items():
                request.add_header(key, value)

        try:
            response = read_request(request,
                                    offset=size,
                                    console=console,
                                    progress_func=progress,
                                    write_func=write)
        except InvalidOffsetException:
            response = None

    if checksum:
        if response:
            origin_checksum = response.headers.get(checksum_header)
        else:
            # check whether file is already complete
            meta = get_url_meta(url, checksum_header)
            origin_checksum = meta.get('checksum')

        if origin_checksum is None:
            raise MissingChecksumHeader

        if checksum.hexdigest() != origin_checksum:
            raise InvalidChecksumException

        if console:
            console.write("checksum/sha256 OK\n")

    return path
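A minimal usage sketch for the new resumable downloader (not part of the commit): the URL is a placeholder, and the checksum header name is an assumption; verification only works if the server actually sends such a header.

import sys
import hashlib

from spacy.en import uget

# Resumable download into /tmp with a progress readout on stdout.
path = uget.download('https://example.com/en_data_all-0.9.5.tgz', '/tmp',
                     checksum=hashlib.sha256(),
                     checksum_header='X-Checksum-Sha256',  # assumed header name
                     console=sys.stdout)
print('Saved to', path)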
@@ -20,8 +20,6 @@ from .tokens.doc cimport get_token_attr
 from .tokens.doc cimport Doc
 from .vocab cimport Vocab

-from libcpp.vector cimport vector
-
 from .attrs import FLAG61 as U_ENT

 from .attrs import FLAG60 as B2_ENT
@@ -221,8 +219,7 @@ cdef class Matcher:
             q = 0
             # Go over the open matches, extending or finalizing if able. Otherwise,
             # we over-write them (q doesn't advance)
-            for i in range(partials.size()):
-                state = partials.at(i)
+            for state in partials:
                 if match(state, token):
                     if is_final(state):
                         label, start, end = get_entity(state, token, token_i)
@@ -233,8 +230,7 @@ cdef class Matcher:
                         q += 1
             partials.resize(q)
             # Check whether we open any new patterns on this token
-            for i in range(self.n_patterns):
-                state = self.patterns[i]
+            for state in self.patterns:
                 if match(state, token):
                     if is_final(state):
                         label, start, end = get_entity(state, token, token_i)
@@ -242,7 +238,16 @@ cdef class Matcher:
                         matches.append((label, start, end))
                     else:
                         partials.push_back(state + 1)
-        doc.ents = [(e.label, e.start, e.end) for e in doc.ents] + matches
+        seen = set()
+        filtered = []
+        for label, start, end in sorted(matches, key=lambda m: (m[1], -(m[1] - m[2]))):
+            if all(i in seen for i in range(start, end)):
+                continue
+            else:
+                for i in range(start, end):
+                    seen.add(i)
+                filtered.append((label, start, end))
+        doc.ents = [(e.label, e.start, e.end) for e in doc.ents] + filtered
        return matches
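A standalone sketch of the de-overlapping step added to Matcher.__call__ above (plain Python, no Cython), handy for checking the behaviour the new matcher tests later in this commit assert:

def filter_overlaps(matches):
    # Same logic as the hunk above: sort by start token (and, for equal
    # starts, by span length), then keep a match only if it covers at least
    # one token index that no earlier kept match has already claimed.
    seen = set()
    filtered = []
    for label, start, end in sorted(matches, key=lambda m: (m[1], -(m[1] - m[2]))):
        if all(i in seen for i in range(start, end)):
            continue
        for i in range(start, end):
            seen.add(i)
        filtered.append((label, start, end))
    return filtered

print(filter_overlaps([('ORG', 9, 11), ('ORG', 10, 11)]))  # -> [('ORG', 9, 11)]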
@@ -72,6 +72,10 @@ cdef class Tokenizer:
         Returns:
             tokens (Doc): A Doc object, giving access to a sequence of LexemeCs.
         """
+        if len(string) >= (2 ** 30):
+            raise ValueError(
+                "String is too long: %d characters. Max is 2**30." % len(string)
+            )
         cdef int length = len(string)
         cdef Doc tokens = Doc(self.vocab)
         if length == 0:
@@ -447,9 +447,9 @@ cdef class Doc:

         cdef Span span = self[start:end]
         # Get LexemeC for newly merged token
-        new_orth = ''.join([t.string for t in span])
+        new_orth = ''.join([t.text_with_ws for t in span])
         if span[-1].whitespace_:
-            new_orth = new_orth[:-1]
+            new_orth = new_orth[:-len(span[-1].whitespace_)]
         cdef const LexemeC* lex = self.vocab.get(self.mem, new_orth)
         # House the new merged token where it starts
         cdef TokenC* token = &self.data[start]
@@ -508,16 +508,26 @@ cdef int set_children_from_heads(TokenC* tokens, int length) except -1:
     cdef TokenC* head
     cdef TokenC* child
     cdef int i
+    # Set number of left/right children to 0. We'll increment it in the loops.
+    for i in range(length):
+        tokens[i].l_kids = 0
+        tokens[i].r_kids = 0
+        tokens[i].l_edge = i
+        tokens[i].r_edge = i
     # Set left edges
     for i in range(length):
         child = &tokens[i]
         head = &tokens[i + child.head]
-        if child < head and child.l_edge < head.l_edge:
-            head.l_edge = child.l_edge
+        if child < head:
+            if child.l_edge < head.l_edge:
+                head.l_edge = child.l_edge
+            head.l_kids += 1

     # Set right edges --- same as above, but iterate in reverse
     for i in range(length-1, -1, -1):
         child = &tokens[i]
         head = &tokens[i + child.head]
-        if child > head and child.r_edge > head.r_edge:
-            head.r_edge = child.r_edge
+        if child > head:
+            if child.r_edge > head.r_edge:
+                head.r_edge = child.r_edge
+            head.r_kids += 1
@@ -278,7 +278,7 @@ cdef class Token:

     property whitespace_:
         def __get__(self):
-            return self.string[self.c.lex.length:]
+            return ' ' if self.c.spacy else ''

     property orth_:
         def __get__(self):
@@ -1,17 +1,102 @@
 import pytest


 from spacy.matcher import Matcher
+from spacy.attrs import LOWER


-@pytest.mark.xfail
 def test_overlap_issue118(EN):
     '''Test a bug that arose from having overlapping matches'''
     doc = EN.tokenizer(u'how many points did lebron james score against the boston celtics last night')
     ORG = doc.vocab.strings['ORG']
-    matcher = Matcher(EN.vocab, {'BostonCeltics': ('ORG', {}, [[{'lower': 'boston'}, {'lower': 'celtics'}], [{'lower': 'celtics'}]])})
+    matcher = Matcher(EN.vocab,
+                      {'BostonCeltics':
+                          ('ORG', {},
+                           [
+                               [{LOWER: 'celtics'}],
+                               [{LOWER: 'boston'}, {LOWER: 'celtics'}],
+                           ]
+                          )
+                      }
+                      )
+
+    assert len(list(doc.ents)) == 0
     matches = matcher(doc)
-    assert matches == [(ORG, 9, 11)]
+    assert matches == [(ORG, 9, 11), (ORG, 10, 11)]
+    ents = list(doc.ents)
+    assert len(ents) == 1
+    assert ents[0].label == ORG
+    assert ents[0].start == 9
+    assert ents[0].end == 11
+
+
+def test_overlap_reorder(EN):
+    '''Test order dependence'''
+    doc = EN.tokenizer(u'how many points did lebron james score against the boston celtics last night')
+    ORG = doc.vocab.strings['ORG']
+    matcher = Matcher(EN.vocab,
+                      {'BostonCeltics':
+                          ('ORG', {},
+                           [
+                               [{LOWER: 'boston'}, {LOWER: 'celtics'}],
+                               [{LOWER: 'celtics'}],
+                           ]
+                          )
+                      }
+                      )
+
+    assert len(list(doc.ents)) == 0
+    matches = matcher(doc)
+    assert matches == [(ORG, 9, 11), (ORG, 10, 11)]
+    ents = list(doc.ents)
+    assert len(ents) == 1
+    assert ents[0].label == ORG
+    assert ents[0].start == 9
+    assert ents[0].end == 11
+
+
+def test_overlap_prefix(EN):
+    '''Test order dependence'''
+    doc = EN.tokenizer(u'how many points did lebron james score against the boston celtics last night')
+    ORG = doc.vocab.strings['ORG']
+    matcher = Matcher(EN.vocab,
+                      {'BostonCeltics':
+                          ('ORG', {},
+                           [
+                               [{LOWER: 'boston'}],
+                               [{LOWER: 'boston'}, {LOWER: 'celtics'}],
+                           ]
+                          )
+                      }
+                      )
+
+    assert len(list(doc.ents)) == 0
+    matches = matcher(doc)
+    assert matches == [(ORG, 9, 10), (ORG, 9, 11)]
+    ents = list(doc.ents)
+    assert len(ents) == 1
+    assert ents[0].label == ORG
+    assert ents[0].start == 9
+    assert ents[0].end == 11
+
+
+def test_overlap_prefix_reorder(EN):
+    '''Test order dependence'''
+    doc = EN.tokenizer(u'how many points did lebron james score against the boston celtics last night')
+    ORG = doc.vocab.strings['ORG']
+    matcher = Matcher(EN.vocab,
+                      {'BostonCeltics':
+                          ('ORG', {},
+                           [
+                               [{LOWER: 'boston'}, {LOWER: 'celtics'}],
+                               [{LOWER: 'boston'}],
+                           ]
+                          )
+                      }
+                      )
+
+    assert len(list(doc.ents)) == 0
+    matches = matcher(doc)
+    assert matches == [(ORG, 9, 10), (ORG, 9, 11)]
     ents = list(doc.ents)
     assert len(ents) == 1
     assert ents[0].label == ORG
@@ -7,6 +7,10 @@ def test_hyphen(en_tokenizer):
     assert len(tokens) == 3


+def test_numeric_range(en_tokenizer):
+    tokens = en_tokenizer('0.1-13.5')
+    assert len(tokens) == 3
+
+
 def test_period(en_tokenizer):
     tokens = en_tokenizer('best.Known')
     assert len(tokens) == 3
@@ -109,3 +109,42 @@ def test_set_ents(EN):
     assert ent.label_ == 'PRODUCT'
     assert ent.start == 2
     assert ent.end == 4
+
+
+def test_merge(EN):
+    doc = EN('WKRO played songs by the beach boys all night')
+
+    assert len(doc) == 9
+    # merge 'The Beach Boys'
+    doc.merge(doc[4].idx, doc[6].idx + len(doc[6]), 'NAMED', 'LEMMA', 'TYPE')
+    assert len(doc) == 7
+
+    assert doc[4].text == 'the beach boys'
+    assert doc[4].text_with_ws == 'the beach boys '
+    assert doc[4].tag_ == 'NAMED'
+
+
+def test_merge_end_string(EN):
+    doc = EN('WKRO played songs by the beach boys all night')
+
+    assert len(doc) == 9
+    # merge 'The Beach Boys'
+    doc.merge(doc[7].idx, doc[8].idx + len(doc[8]), 'NAMED', 'LEMMA', 'TYPE')
+    assert len(doc) == 8
+
+    assert doc[7].text == 'all night'
+    assert doc[7].text_with_ws == 'all night'
+
+
+@pytest.mark.models
+def test_merge_children(EN):
+    """Test that attachments work correctly after merging."""
+    doc = EN('WKRO played songs by the beach boys all night')
+    # merge 'The Beach Boys'
+    doc.merge(doc[4].idx, doc[6].idx + len(doc[6]), 'NAMED', 'LEMMA', 'TYPE')
+
+    for word in doc:
+        if word.i < word.head.i:
+            assert word in list(word.head.lefts)
+        elif word.i > word.head.i:
+            assert word in list(word.head.rights)
@@ -1,8 +1,11 @@
 #!/usr/bin/env python
-import sys
-import re
+from __future__ import unicode_literals
 import os
 import ast
+import io
+
+import plac

 # cgi.escape is deprecated since py32
 try:
@@ -11,55 +14,62 @@ except ImportError:
     from cgi import escape


-src_dirname = sys.argv[1]
-dst_dirname = sys.argv[2]
-prefix = "test_"
-
-
-for filename in os.listdir(src_dirname):
-    match = re.match(re.escape(prefix) + r"(.+)\.py$", filename)
-    if not match:
-        continue
-
-    name = match.group(1)
-    source = open(os.path.join(src_dirname, filename)).readlines()
-    tree = ast.parse("".join(source))
-
-    for root in tree.body:
-        if isinstance(root, ast.FunctionDef) and root.name.startswith(prefix):
-
-            # only ast.expr and ast.stmt have line numbers, see:
-            # https://docs.python.org/2/library/ast.html#ast.AST.lineno
-            line_numbers = []
-
-            for node in ast.walk(root):
-                if hasattr(node, "lineno"):
-                    line_numbers.append(node.lineno)
-
-            body = source[min(line_numbers)-1:max(line_numbers)]
-            while not body[0][0].isspace():
-                body = body[1:]
-
-            # make sure we are inside an indented function body
-            assert all([l[0].isspace() for l in body])
-
-            offset = 0
-            for line in body:
-                match = re.search(r"[^\s]", line)
-                if match:
-                    offset = match.start(0)
-                    break
-
-            # remove indentation
-            assert offset > 0
-
-            for i in range(len(body)):
-                body[i] = body[i][offset:] if len(body[i]) > offset else "\n"
-
-            # make sure empty lines contain a newline
-            assert all([l[-1] == "\n" for l in body])
-
-            code_filename = "%s.%s" % (name, root.name[len(prefix):])
-
-            with open(os.path.join(dst_dirname, code_filename), "w") as f:
-                f.write(escape("".join(body)))
+# e.g. python website/create_code_samples tests/website/ website/src/
+def main(src_dirname, dst_dirname):
+    prefix = "test_"
+
+    for filename in os.listdir(src_dirname):
+        if not filename.startswith('test_'):
+            continue
+        if not filename.endswith('.py'):
+            continue
+
+        # Remove test_ prefix and .py suffix
+        name = filename[6:-3]
+        with io.open(os.path.join(src_dirname, filename), 'r', encoding='utf8') as file_:
+            source = file_.readlines()
+        tree = ast.parse("".join(source))
+
+        for root in tree.body:
+            if isinstance(root, ast.FunctionDef) and root.name.startswith(prefix):
+
+                # only ast.expr and ast.stmt have line numbers, see:
+                # https://docs.python.org/2/library/ast.html#ast.AST.lineno
+                line_numbers = []
+
+                for node in ast.walk(root):
+                    if hasattr(node, "lineno"):
+                        line_numbers.append(node.lineno)
+
+                body = source[min(line_numbers)-1:max(line_numbers)]
+                while not body[0][0].isspace():
+                    body = body[1:]
+
+                # make sure we are inside an indented function body
+                assert all([l[0].isspace() for l in body])
+
+                offset = 0
+                for line in body:
+                    match = re.search(r"[^\s]", line)
+                    if match:
+                        offset = match.start(0)
+                        break
+
+                # remove indentation
+                assert offset > 0
+
+                for i in range(len(body)):
+                    body[i] = body[i][offset:] if len(body[i]) > offset else "\n"
+
+                # make sure empty lines contain a newline
+                assert all([l[-1] == "\n" for l in body])
+
+                code_filename = "%s.%s" % (name, root.name[len(prefix):])
+
+                with io.open(os.path.join(dst_dirname, code_filename),
+                             "w", encoding='utf8') as f:
+                    f.write(escape("".join(body)))
+
+
+if __name__ == '__main__':
+    plac.call(main)