Merge remote-tracking branch 'refs/remotes/honnibal/master'

maxirmx committed 2015-10-20 23:27:20 +03:00
commit 14b89ff1c5
17 changed files with 779 additions and 93 deletions

View File

@@ -11,6 +11,7 @@ import ujson
 import codecs
 from preshed.counter import PreshCounter
 from joblib import Parallel, delayed
+import io
 from spacy.en import English
 from spacy.strings import StringStore

examples/nn_text_class.py Normal file (273 additions)
View File

@@ -0,0 +1,273 @@
from __future__ import unicode_literals
from __future__ import print_function
from __future__ import division

from collections import defaultdict
from pathlib import Path
import numpy
import plac

import spacy.en


def read_data(nlp, data_dir):
    for subdir, label in (('pos', 1), ('neg', 0)):
        for filename in (data_dir / subdir).iterdir():
            text = filename.open().read()
            doc = nlp(text)
            if len(doc) >= 1:
                yield doc, label


def partition(examples, split_size):
    examples = list(examples)
    numpy.random.shuffle(examples)
    n_docs = len(examples)
    split = int(n_docs * split_size)
    return examples[:split], examples[split:]


def minibatch(data, bs=24):
    for i in range(0, len(data), bs):
        yield data[i:i+bs]


class Extractor(object):
    def __init__(self, nlp, vector_length, dropout=0.3):
        self.nlp = nlp
        self.dropout = dropout
        self.vector = numpy.zeros((vector_length, ))

    def doc2bow(self, doc, dropout=None):
        if dropout is None:
            dropout = self.dropout
        bow = defaultdict(int)
        all_words = defaultdict(int)
        for word in doc:
            if numpy.random.random() >= dropout and not word.is_punct:
                bow[word.lower] += 1
            all_words[word.lower] += 1
        if sum(bow.values()) >= 1:
            return bow
        else:
            return all_words

    def bow2vec(self, bow, E):
        self.vector.fill(0)
        n = 0
        for orth_id, freq in bow.items():
            self.vector += self.nlp.vocab[self.nlp.vocab.strings[orth_id]].repvec * freq
            # Apply the fine-tuning we've learned
            if orth_id < E.shape[0]:
                self.vector += E[orth_id] * freq
            n += freq
        return self.vector / n


class NeuralNetwork(object):
    def __init__(self, depth, width, n_classes, n_vocab, extracter, optimizer):
        self.depth = depth
        self.width = width
        self.n_classes = n_classes
        self.weights = Params.random(depth, width, width, n_classes, n_vocab)
        self.doc2bow = extracter.doc2bow
        self.bow2vec = extracter.bow2vec
        self.optimizer = optimizer
        self._gradient = Params.zero(depth, width, width, n_classes, n_vocab)
        self._activity = numpy.zeros((depth, width))

    def train(self, batch):
        activity = self._activity
        gradient = self._gradient
        activity.fill(0)
        gradient.data.fill(0)
        loss = 0
        word_freqs = defaultdict(int)
        for doc, label in batch:
            word_ids = self.doc2bow(doc)
            vector = self.bow2vec(word_ids, self.weights.E)
            self.forward(activity, vector)
            loss += self.backprop(vector, gradient, activity, word_ids, label)
            for w, freq in word_ids.items():
                word_freqs[w] += freq
        self.optimizer(self.weights, gradient, len(batch), word_freqs)
        return loss

    def predict(self, doc):
        actv = self._activity
        actv.fill(0)
        W = self.weights.W
        b = self.weights.b
        E = self.weights.E

        vector = self.bow2vec(self.doc2bow(doc, dropout=0.0), E)
        self.forward(actv, vector)
        return numpy.argmax(softmax(actv[-1], W[-1], b[-1]))

    def forward(self, actv, in_):
        actv.fill(0)
        W = self.weights.W; b = self.weights.b
        actv[0] = relu(in_, W[0], b[0])
        for i in range(1, self.depth):
            actv[i] = relu(actv[i-1], W[i], b[i])

    def backprop(self, input_vector, gradient, activity, ids, label):
        W = self.weights.W
        b = self.weights.b

        target = numpy.zeros(self.n_classes)
        target[label] = 1.0
        pred = softmax(activity[-1], W[-1], b[-1])
        delta = pred - target

        for i in range(self.depth, 0, -1):
            gradient.b[i] += delta
            gradient.W[i] += numpy.outer(delta, activity[i-1])
            delta = d_relu(activity[i-1]) * W[i].T.dot(delta)

        gradient.b[0] += delta
        gradient.W[0] += numpy.outer(delta, input_vector)

        tuning = W[0].T.dot(delta).reshape((self.width,)) / len(ids)
        for w, freq in ids.items():
            if w < gradient.E.shape[0]:
                gradient.E[w] += tuning * freq
        return -sum(target * numpy.log(pred))


def softmax(actvn, W, b):
    w = W.dot(actvn) + b
    ew = numpy.exp(w - max(w))
    return (ew / sum(ew)).ravel()


def relu(actvn, W, b):
    x = W.dot(actvn) + b
    return x * (x > 0)


def d_relu(x):
    return x > 0


class Adagrad(object):
    def __init__(self, lr, rho):
        self.eps = 1e-3
        # initial learning rate
        self.learning_rate = lr
        self.rho = rho
        # stores sum of squared gradients
        #self.h = numpy.zeros(self.dim)
        #self._curr_rate = numpy.zeros(self.h.shape)
        self.h = None
        self._curr_rate = None

    def __call__(self, weights, gradient, batch_size, word_freqs):
        if self.h is None:
            self.h = numpy.zeros(gradient.data.shape)
            self._curr_rate = numpy.zeros(gradient.data.shape)
        self.L2_penalty(gradient, weights, word_freqs)
        update = self.rescale(gradient.data / batch_size)
        weights.data -= update

    def rescale(self, gradient):
        if self.h is None:
            self.h = numpy.zeros(gradient.data.shape)
            self._curr_rate = numpy.zeros(gradient.data.shape)
        self._curr_rate.fill(0)
        self.h += gradient ** 2
        self._curr_rate = self.learning_rate / (numpy.sqrt(self.h) + self.eps)
        return self._curr_rate * gradient

    def L2_penalty(self, gradient, weights, word_freqs):
        # L2 Regularization
        for i in range(len(weights.W)):
            gradient.W[i] += weights.W[i] * self.rho
            gradient.b[i] += weights.b[i] * self.rho
        for w, freq in word_freqs.items():
            if w < gradient.E.shape[0]:
                gradient.E[w] += weights.E[w] * self.rho


class Params(object):
    @classmethod
    def zero(cls, depth, n_embed, n_hidden, n_labels, n_vocab):
        return cls(depth, n_embed, n_hidden, n_labels, n_vocab, lambda x: numpy.zeros((x,)))

    @classmethod
    def random(cls, depth, nE, nH, nL, nV):
        return cls(depth, nE, nH, nL, nV, lambda x: (numpy.random.rand(x) * 2 - 1) * 0.08)

    def __init__(self, depth, n_embed, n_hidden, n_labels, n_vocab, initializer):
        nE = n_embed; nH = n_hidden; nL = n_labels; nV = n_vocab
        n_weights = sum([
            (nE * nH) + nH,
            (nH * nH + nH) * depth,
            (nH * nL) + nL,
            (nV * nE)
        ])
        self.data = initializer(n_weights)
        self.W = []
        self.b = []
        i = self._add_layer(0, nE, nH)
        for _ in range(1, depth):
            i = self._add_layer(i, nH, nH)
        i = self._add_layer(i, nL, nH)
        self.E = self.data[i : i + (nV * nE)].reshape((nV, nE))
        self.E.fill(0)

    def _add_layer(self, start, x, y):
        end = start + (x * y)
        self.W.append(self.data[start : end].reshape((x, y)))
        self.b.append(self.data[end : end + x].reshape((x, )))
        return end + x


@plac.annotations(
    data_dir=("Data directory", "positional", None, Path),
    n_iter=("Number of iterations (epochs)", "option", "i", int),
    width=("Size of hidden layers", "option", "H", int),
    depth=("Depth", "option", "d", int),
    dropout=("Drop-out rate", "option", "r", float),
    rho=("Regularization penalty", "option", "p", float),
    eta=("Learning rate", "option", "e", float),
    batch_size=("Batch size", "option", "b", int),
    vocab_size=("Number of words to fine-tune", "option", "w", int),
)
def main(data_dir, depth=3, width=300, n_iter=5, vocab_size=40000,
         batch_size=24, dropout=0.3, rho=1e-5, eta=0.005):
    n_classes = 2
    print("Loading")
    nlp = spacy.en.English(parser=False)
    train_data, dev_data = partition(read_data(nlp, data_dir / 'train'), 0.8)
    print("Begin training")
    extracter = Extractor(nlp, width, dropout=0.3)
    optimizer = Adagrad(eta, rho)
    model = NeuralNetwork(depth, width, n_classes, vocab_size, extracter, optimizer)
    prev_best = 0
    best_weights = None
    for epoch in range(n_iter):
        numpy.random.shuffle(train_data)
        train_loss = 0.0
        for batch in minibatch(train_data, bs=batch_size):
            train_loss += model.train(batch)
        n_correct = sum(model.predict(x) == y for x, y in dev_data)
        print(epoch, train_loss, n_correct / len(dev_data))
        if n_correct >= prev_best:
            best_weights = model.weights.data.copy()
            prev_best = n_correct
    model.weights.data = best_weights
    print("Evaluating")
    eval_data = list(read_data(nlp, data_dir / 'test'))
    n_correct = sum(model.predict(x) == y for x, y in eval_data)
    print(n_correct / len(eval_data))


if __name__ == '__main__':
    #import cProfile
    #import pstats
    #cProfile.runctx("main(Path('data/aclImdb'))", globals(), locals(), "Profile.prof")
    #s = pstats.Stats("Profile.prof")
    #s.strip_dirs().sort_stats("time").print_stats(100)
    plac.call(main)
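As a usage note (not part of the commit): the example assumes an IMDB-style corpus layout, i.e. a data directory containing train/pos, train/neg, test/pos and test/neg folders of plain-text reviews, which is what the data/aclImdb path in the commented-out profiling lines refers to. A minimal, hypothetical way to drive it:

# Hypothetical driver for the new example; the data path and option values are assumptions.
from pathlib import Path

import nn_text_class   # i.e. examples/nn_text_class.py, assuming it is importable from the cwd

# roughly equivalent to: python examples/nn_text_class.py data/aclImdb -i 5 -d 3 -H 300
nn_text_class.main(Path('data/aclImdb'), depth=3, width=300, n_iter=5)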

fabfile.py vendored (2 changes)
View File

@@ -48,7 +48,7 @@ def prebuild(build_dir='/tmp/build_spacy'):
         local('virtualenv ' + build_venv)
         with prefix('cd %s && PYTHONPATH=`pwd` && . %s/bin/activate' % (build_dir, build_venv)):
             local('pip install cython fabric fabtools pytest')
-            local('pip install -r requirements.txt')
+            local('pip install --no-cache-dir -r requirements.txt')
             local('fab clean make')
             local('cp -r %s/corpora/en/wordnet corpora/en/' % spacy_dir)
             local('cp %s/corpora/en/freqs.txt.gz corpora/en/' % spacy_dir)

View File

@@ -342,7 +342,7 @@ hardcoded_specials = {
     "\n": [{"F": "\n", "pos": "SP"}],
     "\t": [{"F": "\t", "pos": "SP"}],
     " ": [{"F": " ", "pos": "SP"}],
-    u"\xa0": [{"F": u"\xa0", "pos": "SP", "L": " "}]
+    u"\u00a0": [{"F": u"\u00a0", "pos": "SP", "L": " "}]
 }

View File

@ -1,3 +1,4 @@
\.\.\. \.\.\.
(?<=[a-z])\.(?=[A-Z]) (?<=[a-z])\.(?=[A-Z])
(?<=[a-zA-Z])-(?=[a-zA-z]) (?<=[a-zA-Z])-(?=[a-zA-z])
(?<=[0-9])-(?=[0-9])
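The only change to this rules file is the last pattern, which lets the tokenizer treat a hyphen between two digits as an infix, so numeric ranges split into three tokens (see the new test_numeric_range test further down). A standalone sketch of the pattern's effect using plain re, outside the tokenizer machinery (illustration only, not part of the commit):

# Sketch only: the new infix pattern in isolation.
import re

numeric_range = re.compile(r"(?<=[0-9])-(?=[0-9])")
print(numeric_range.split("0.1-13.5"))           # ['0.1', '13.5'] -- the '-' becomes its own token in spaCy
print(bool(numeric_range.search("well-known")))  # False -- letter-letter hyphens stay covered by the rule above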

View File

@@ -6,7 +6,6 @@ thinc == 3.3
 murmurhash == 0.24
 text-unidecode
 numpy
-wget
 plac
 six
 ujson

View File

@@ -162,7 +162,7 @@ def run_setup(exts):
         ext_modules=exts,
         license="MIT",
         install_requires=['numpy', 'murmurhash', 'cymem >= 1.11', 'preshed >= 0.42',
-                          'thinc == 3.3', "text_unidecode", 'wget', 'plac', 'six',
+                          'thinc == 3.3', "text_unidecode", 'plac', 'six',
                           'ujson', 'cloudpickle'],
         setup_requires=["headers_workaround"],
         cmdclass = {'build_ext': build_ext_subclass },

@@ -175,13 +175,14 @@ def run_setup(exts):
     headers_workaround.install_headers('numpy')

-VERSION = '0.95'
+VERSION = '0.96'

 def main(modules, is_pypy):
     language = "cpp"
     includes = ['.', path.join(sys.prefix, 'include')]
     if sys.platform.startswith('darwin'):
-        compile_options['other'].append(['-mmacosx-version-min=10.8', '-stdlib=libc++'])
-        link_opions['other'].append('-lc++')
+        compile_options['other'].append('-mmacosx-version-min=10.8')
+        compile_options['other'].append('-stdlib=libc++')
+        link_options['other'].append('-lc++')
     if use_cython:
         cython_setup(modules, language, includes)
     else:

View File

@@ -1,11 +1,13 @@
 from __future__ import print_function
 from os import path
+import sys
 import os
 import tarfile
 import shutil
-import wget
 import plac
+from . import uget

 # TODO: Read this from the same source as the setup
 VERSION = '0.9.5'

@@ -13,39 +15,45 @@ AWS_STORE = 'https://s3-us-west-1.amazonaws.com/media.spacynlp.com'
 ALL_DATA_DIR_URL = '%s/en_data_all-%s.tgz' % (AWS_STORE, VERSION)

-DEST_DIR = path.join(path.dirname(__file__), 'data')
+DEST_DIR = path.join(path.dirname(path.abspath(__file__)), 'data')

-def download_file(url, out):
-    wget.download(url, out=out)
-    return url.rsplit('/', 1)[1]
+def download_file(url, dest_dir):
+    return uget.download(url, dest_dir, console=sys.stdout)

 def install_data(url, dest_dir):
     filename = download_file(url, dest_dir)
-    t = tarfile.open(path.join(dest_dir, filename))
+    t = tarfile.open(filename)
     t.extractall(dest_dir)

 def install_parser_model(url, dest_dir):
     filename = download_file(url, dest_dir)
-    t = tarfile.open(path.join(dest_dir, filename), mode=":gz")
-    t.extractall(path.dirname(__file__))
+    t = tarfile.open(filename, mode=":gz")
+    t.extractall(dest_dir)

 def install_dep_vectors(url, dest_dir):
-    if not os.path.exists(dest_dir):
-        os.mkdir(dest_dir)
-    filename = download_file(url, dest_dir)
+    download_file(url, dest_dir)

-def main(data_size='all'):
+@plac.annotations(
+    force=("Force overwrite", "flag", "f", bool),
+)
+def main(data_size='all', force=False):
     if data_size == 'all':
         data_url = ALL_DATA_DIR_URL
     elif data_size == 'small':
         data_url = SM_DATA_DIR_URL
-    if path.exists(DEST_DIR):
+    if force and path.exists(DEST_DIR):
         shutil.rmtree(DEST_DIR)
-    install_data(data_url, path.dirname(DEST_DIR))
+    if not os.path.exists(DEST_DIR):
+        os.makedirs(DEST_DIR)
+    install_data(data_url, DEST_DIR)

 if __name__ == '__main__':
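With the plac annotations above, main() now removes an existing data directory only when the force flag is set. A hedged sketch of exercising that path programmatically (the module path is assumed importable; this is not part of the diff):

# Sketch only: equivalent to passing the new -f flag on the command line.
from spacy.en import download

download.main(data_size='all', force=True)   # wipe spacy/en/data if present, then fetch and extract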

spacy/en/uget.py Normal file (246 additions)
View File

@@ -0,0 +1,246 @@
import os
import time
import io
import math
import re

try:
    from urllib.parse import urlparse
    from urllib.request import urlopen, Request
    from urllib.error import HTTPError
except ImportError:
    from urllib2 import urlopen, urlparse, Request, HTTPError


class UnknownContentLengthException(Exception): pass
class InvalidChecksumException(Exception): pass
class UnsupportedHTTPCodeException(Exception): pass
class InvalidOffsetException(Exception): pass
class MissingChecksumHeader(Exception): pass


CHUNK_SIZE = 16 * 1024


class RateSampler(object):
    def __init__(self, period=1):
        self.rate = None
        self.reset = True
        self.period = period

    def __enter__(self):
        if self.reset:
            self.reset = False
            self.start = time.time()
            self.counter = 0

    def __exit__(self, type, value, traceback):
        elapsed = time.time() - self.start
        if elapsed >= self.period:
            self.reset = True
            self.rate = float(self.counter) / elapsed

    def update(self, value):
        self.counter += value

    def format(self, unit="MB"):
        if self.rate is None:
            return None
        divisor = {'MB': 1048576, 'kB': 1024}
        return "%0.2f%s/s" % (self.rate / divisor[unit], unit)


class TimeEstimator(object):
    def __init__(self, cooldown=1):
        self.cooldown = cooldown
        self.start = time.time()
        self.time_left = None

    def update(self, bytes_read, total_size):
        elapsed = time.time() - self.start
        if elapsed > self.cooldown:
            self.time_left = math.ceil(elapsed * total_size /
                                       bytes_read - elapsed)

    def format(self):
        if self.time_left is None:
            return None
        res = "eta "
        if self.time_left / 60 >= 1:
            res += "%dm " % (self.time_left / 60)
        return res + "%ds" % (self.time_left % 60)


def format_bytes_read(bytes_read, unit="MB"):
    divisor = {'MB': 1048576, 'kB': 1024}
    return "%0.2f%s" % (float(bytes_read) / divisor[unit], unit)


def format_percent(bytes_read, total_size):
    percent = round(bytes_read * 100.0 / total_size, 2)
    return "%0.2f%%" % percent


def get_content_range(response):
    content_range = response.headers.get('Content-Range', "").strip()
    if content_range:
        m = re.match(r"bytes (\d+)-(\d+)/(\d+)", content_range)
        if m:
            return [int(v) for v in m.groups()]


def get_content_length(response):
    if 'Content-Length' not in response.headers:
        raise UnknownContentLengthException
    return int(response.headers.get('Content-Length').strip())


def get_url_meta(url, checksum_header=None):
    class HeadRequest(Request):
        def get_method(self):
            return "HEAD"

    r = urlopen(HeadRequest(url))
    res = {'size': get_content_length(r)}

    if checksum_header:
        value = r.headers.get(checksum_header)
        if value:
            res['checksum'] = value

    r.close()
    return res


def progress(console, bytes_read, total_size, transfer_rate, eta):
    fields = [
        format_bytes_read(bytes_read),
        format_percent(bytes_read, total_size),
        transfer_rate.format(),
        eta.format(),
        " " * 10,
    ]
    console.write("Downloaded %s\r" % " ".join(filter(None, fields)))
    console.flush()


def read_request(request, offset=0, console=None,
                 progress_func=None, write_func=None):
    # support partial downloads
    if offset > 0:
        request.add_header('Range', "bytes=%s-" % offset)

    try:
        response = urlopen(request)
    except HTTPError as e:
        if e.code == 416:  # Requested Range Not Satisfiable
            raise InvalidOffsetException
        # TODO add http error handling here
        raise UnsupportedHTTPCodeException(e.code)

    total_size = get_content_length(response) + offset
    bytes_read = offset

    # sanity checks
    if response.code == 200:  # OK
        assert offset == 0
    elif response.code == 206:  # Partial content
        range_start, range_end, range_total = get_content_range(response)
        assert range_start == offset
        assert range_total == total_size
        assert range_end + 1 - range_start == total_size - bytes_read
    else:
        raise UnsupportedHTTPCodeException(response.code)

    eta = TimeEstimator()
    transfer_rate = RateSampler()

    if console:
        if offset > 0:
            console.write("Continue downloading...\n")
        else:
            console.write("Downloading...\n")

    while True:
        with transfer_rate:
            chunk = response.read(CHUNK_SIZE)
            if not chunk:
                if progress_func and console:
                    console.write('\n')
                break

            bytes_read += len(chunk)

            transfer_rate.update(len(chunk))
            eta.update(bytes_read - offset, total_size - offset)

            if progress_func and console:
                progress_func(console, bytes_read, total_size, transfer_rate, eta)

            if write_func:
                write_func(chunk)

    response.close()
    assert bytes_read == total_size
    return response


def download(url, path=".",
             checksum=None, checksum_header=None,
             headers=None, console=None):
    if os.path.isdir(path):
        path = os.path.join(path, url.rsplit('/', 1)[1])
    path = os.path.abspath(path)

    with io.open(path, "a+b") as f:
        size = f.tell()

        # update checksum of partially downloaded file
        if checksum:
            f.seek(0, os.SEEK_SET)
            for chunk in iter(lambda: f.read(CHUNK_SIZE), b""):
                checksum.update(chunk)

        def write(chunk):
            if checksum:
                checksum.update(chunk)
            f.write(chunk)

        request = Request(url)

        # request headers
        if headers:
            for key, value in headers.items():
                request.add_header(key, value)

        try:
            response = read_request(request,
                                    offset=size,
                                    console=console,
                                    progress_func=progress,
                                    write_func=write)
        except InvalidOffsetException:
            response = None

    if checksum:
        if response:
            origin_checksum = response.headers.get(checksum_header)
        else:
            # check whether file is already complete
            meta = get_url_meta(url, checksum_header)
            origin_checksum = meta.get('checksum')

        if origin_checksum is None:
            raise MissingChecksumHeader

        if checksum.hexdigest() != origin_checksum:
            raise InvalidChecksumException

        if console:
            console.write("checksum/sha256 OK\n")

    return path
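As a usage sketch (not part of the diff): download.py above only calls uget.download(url, dest_dir, console=sys.stdout), but the checksum path can also be exercised by passing a hashlib object. The URL and header name below are placeholders, not values from the commit:

# Hypothetical example of calling uget.download directly.
import sys
import hashlib

from spacy.en import uget

path = uget.download(
    'https://example.com/en_data_all-0.9.5.tgz',   # placeholder URL
    '/tmp',                                        # existing directory: the filename is taken from the URL
    checksum=hashlib.sha256(),                     # fed any partial file first, then each new chunk
    checksum_header='X-Checksum-Sha256',           # hypothetical header carrying the expected digest
    console=sys.stdout)                            # enables the progress output
print(path)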

View File

@@ -20,8 +20,6 @@ from .tokens.doc cimport get_token_attr
 from .tokens.doc cimport Doc
 from .vocab cimport Vocab

-from libcpp.vector cimport vector

 from .attrs import FLAG61 as U_ENT
 from .attrs import FLAG60 as B2_ENT

@@ -221,8 +219,7 @@
             q = 0
             # Go over the open matches, extending or finalizing if able. Otherwise,
             # we over-write them (q doesn't advance)
-            for i in range(partials.size()):
-                state = partials.at(i)
+            for state in partials:
                 if match(state, token):
                     if is_final(state):
                         label, start, end = get_entity(state, token, token_i)

@@ -233,8 +230,7 @@
                        q += 1
             partials.resize(q)
             # Check whether we open any new patterns on this token
-            for i in range(self.n_patterns):
-                state = self.patterns[i]
+            for state in self.patterns:
                 if match(state, token):
                     if is_final(state):
                         label, start, end = get_entity(state, token, token_i)

@@ -242,7 +238,16 @@
                         matches.append((label, start, end))
                     else:
                         partials.push_back(state + 1)
-        doc.ents = [(e.label, e.start, e.end) for e in doc.ents] + matches
+        seen = set()
+        filtered = []
+        for label, start, end in sorted(matches, key=lambda m: (m[1], -(m[1] - m[2]))):
+            if all(i in seen for i in range(start, end)):
+                continue
+            else:
+                for i in range(start, end):
+                    seen.add(i)
+                filtered.append((label, start, end))
+        doc.ents = [(e.label, e.start, e.end) for e in doc.ents] + filtered
         return matches
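To make the new filtering step concrete, here is the same logic lifted into plain Python and run on the match list from the issue-118 test below, (ORG, 9, 11) and (ORG, 10, 11); only the span covering not-yet-seen tokens survives into doc.ents (illustration only, not part of the commit):

# Plain-Python rendering of the greedy filter added above.
matches = [('ORG', 9, 11), ('ORG', 10, 11)]

seen = set()
filtered = []
# same sort key as the Cython code: by start index, then by span length
for label, start, end in sorted(matches, key=lambda m: (m[1], -(m[1] - m[2]))):
    if all(i in seen for i in range(start, end)):
        continue  # every token is already covered by an accepted match
    for i in range(start, end):
        seen.add(i)
    filtered.append((label, start, end))

print(filtered)   # [('ORG', 9, 11)] -- the (10, 11) sub-span is dropped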

View File

@@ -72,6 +72,10 @@ cdef class Tokenizer:
         Returns:
             tokens (Doc): A Doc object, giving access to a sequence of LexemeCs.
         """
+        if len(string) >= (2 ** 30):
+            raise ValueError(
+                "String is too long: %d characters. Max is 2**30." % len(string)
+            )
         cdef int length = len(string)
         cdef Doc tokens = Doc(self.vocab)
         if length == 0:

View File

@@ -447,9 +447,9 @@ cdef class Doc:
         cdef Span span = self[start:end]
         # Get LexemeC for newly merged token
-        new_orth = ''.join([t.string for t in span])
+        new_orth = ''.join([t.text_with_ws for t in span])
         if span[-1].whitespace_:
-            new_orth = new_orth[:-1]
+            new_orth = new_orth[:-len(span[-1].whitespace_)]
         cdef const LexemeC* lex = self.vocab.get(self.mem, new_orth)
         # House the new merged token where it starts
         cdef TokenC* token = &self.data[start]

@@ -508,16 +508,26 @@ cdef int set_children_from_heads(TokenC* tokens, int length) except -1:
     cdef TokenC* head
     cdef TokenC* child
     cdef int i
+    # Set number of left/right children to 0. We'll increment it in the loops.
+    for i in range(length):
+        tokens[i].l_kids = 0
+        tokens[i].r_kids = 0
+        tokens[i].l_edge = i
+        tokens[i].r_edge = i
     # Set left edges
     for i in range(length):
         child = &tokens[i]
         head = &tokens[i + child.head]
-        if child < head and child.l_edge < head.l_edge:
-            head.l_edge = child.l_edge
+        if child < head:
+            if child.l_edge < head.l_edge:
+                head.l_edge = child.l_edge
+            head.l_kids += 1
     # Set right edges --- same as above, but iterate in reverse
     for i in range(length-1, -1, -1):
         child = &tokens[i]
         head = &tokens[i + child.head]
-        if child > head and child.r_edge > head.r_edge:
-            head.r_edge = child.r_edge
+        if child > head:
+            if child.r_edge > head.r_edge:
+                head.r_edge = child.r_edge
+            head.r_kids += 1
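The reworked loops above now also count left and right children while propagating the subtree edges. A small pure-Python analogue over a relative-offset head array may help; the three-token sentence is invented for illustration and is not part of the commit:

# Illustration of the l_edge/r_edge/l_kids/r_kids bookkeeping, using relative head offsets.
heads = [1, 1, 0]            # 'the' -> 'boys', 'boys' -> 'played', 'played' is the root
n = len(heads)
l_edge = list(range(n)); r_edge = list(range(n))
l_kids = [0] * n; r_kids = [0] * n

# left-to-right pass: a child to the left may extend its head's left edge
for i in range(n):
    head = i + heads[i]
    if i < head:
        if l_edge[i] < l_edge[head]:
            l_edge[head] = l_edge[i]
        l_kids[head] += 1

# right-to-left pass: a child to the right may extend its head's right edge
for i in range(n - 1, -1, -1):
    head = i + heads[i]
    if i > head:
        if r_edge[i] > r_edge[head]:
            r_edge[head] = r_edge[i]
        r_kids[head] += 1

print(l_edge, r_edge, l_kids, r_kids)   # [0, 0, 0] [0, 1, 2] [0, 1, 1] [0, 0, 0]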

View File

@@ -278,7 +278,7 @@ cdef class Token:
     property whitespace_:
         def __get__(self):
-            return self.string[self.c.lex.length:]
+            return ' ' if self.c.spacy else ''

     property orth_:
         def __get__(self):

View File

@@ -1,17 +1,102 @@
 import pytest

 from spacy.matcher import Matcher
+from spacy.attrs import LOWER

-@pytest.mark.xfail
 def test_overlap_issue118(EN):
     '''Test a bug that arose from having overlapping matches'''
     doc = EN.tokenizer(u'how many points did lebron james score against the boston celtics last night')
     ORG = doc.vocab.strings['ORG']
-    matcher = Matcher(EN.vocab, {'BostonCeltics': ('ORG', {}, [[{'lower': 'boston'}, {'lower': 'celtics'}], [{'lower': 'celtics'}]])})
+    matcher = Matcher(EN.vocab,
+        {'BostonCeltics':
+            ('ORG', {},
+                [
+                    [{LOWER: 'celtics'}],
+                    [{LOWER: 'boston'}, {LOWER: 'celtics'}],
+                ]
+            )
+        }
+    )
+
+    assert len(list(doc.ents)) == 0
     matches = matcher(doc)
-    assert matches == [(ORG, 9, 11)]
+    assert matches == [(ORG, 9, 11), (ORG, 10, 11)]
+    ents = list(doc.ents)
+    assert len(ents) == 1
+    assert ents[0].label == ORG
+    assert ents[0].start == 9
+    assert ents[0].end == 11
+
+
+def test_overlap_reorder(EN):
+    '''Test order dependence'''
+    doc = EN.tokenizer(u'how many points did lebron james score against the boston celtics last night')
+    ORG = doc.vocab.strings['ORG']
+    matcher = Matcher(EN.vocab,
+        {'BostonCeltics':
+            ('ORG', {},
+                [
+                    [{LOWER: 'boston'}, {LOWER: 'celtics'}],
+                    [{LOWER: 'celtics'}],
+                ]
+            )
+        }
+    )
+
+    assert len(list(doc.ents)) == 0
+    matches = matcher(doc)
+    assert matches == [(ORG, 9, 11), (ORG, 10, 11)]
+    ents = list(doc.ents)
+    assert len(ents) == 1
+    assert ents[0].label == ORG
+    assert ents[0].start == 9
+    assert ents[0].end == 11
+
+
+def test_overlap_prefix(EN):
+    '''Test order dependence'''
+    doc = EN.tokenizer(u'how many points did lebron james score against the boston celtics last night')
+    ORG = doc.vocab.strings['ORG']
+    matcher = Matcher(EN.vocab,
+        {'BostonCeltics':
+            ('ORG', {},
+                [
+                    [{LOWER: 'boston'}],
+                    [{LOWER: 'boston'}, {LOWER: 'celtics'}],
+                ]
+            )
+        }
+    )
+
+    assert len(list(doc.ents)) == 0
+    matches = matcher(doc)
+    assert matches == [(ORG, 9, 10), (ORG, 9, 11)]
+    ents = list(doc.ents)
+    assert len(ents) == 1
+    assert ents[0].label == ORG
+    assert ents[0].start == 9
+    assert ents[0].end == 11
+
+
+def test_overlap_prefix_reorder(EN):
+    '''Test order dependence'''
+    doc = EN.tokenizer(u'how many points did lebron james score against the boston celtics last night')
+    ORG = doc.vocab.strings['ORG']
+    matcher = Matcher(EN.vocab,
+        {'BostonCeltics':
+            ('ORG', {},
+                [
+                    [{LOWER: 'boston'}, {LOWER: 'celtics'}],
+                    [{LOWER: 'boston'}],
+                ]
+            )
+        }
+    )
+
+    assert len(list(doc.ents)) == 0
+    matches = matcher(doc)
+    assert matches == [(ORG, 9, 10), (ORG, 9, 11)]
     ents = list(doc.ents)
     assert len(ents) == 1
     assert ents[0].label == ORG

View File

@@ -7,6 +7,10 @@ def test_hyphen(en_tokenizer):
     assert len(tokens) == 3

+def test_numeric_range(en_tokenizer):
+    tokens = en_tokenizer('0.1-13.5')
+    assert len(tokens) == 3
+
 def test_period(en_tokenizer):
     tokens = en_tokenizer('best.Known')
     assert len(tokens) == 3

View File

@@ -109,3 +109,42 @@ def test_set_ents(EN):
     assert ent.label_ == 'PRODUCT'
     assert ent.start == 2
     assert ent.end == 4
+
+
+def test_merge(EN):
+    doc = EN('WKRO played songs by the beach boys all night')
+    assert len(doc) == 9
+    # merge 'The Beach Boys'
+    doc.merge(doc[4].idx, doc[6].idx + len(doc[6]), 'NAMED', 'LEMMA', 'TYPE')
+    assert len(doc) == 7
+
+    assert doc[4].text == 'the beach boys'
+    assert doc[4].text_with_ws == 'the beach boys '
+    assert doc[4].tag_ == 'NAMED'
+
+
+def test_merge_end_string(EN):
+    doc = EN('WKRO played songs by the beach boys all night')
+    assert len(doc) == 9
+    # merge 'The Beach Boys'
+    doc.merge(doc[7].idx, doc[8].idx + len(doc[8]), 'NAMED', 'LEMMA', 'TYPE')
+    assert len(doc) == 8
+
+    assert doc[7].text == 'all night'
+    assert doc[7].text_with_ws == 'all night'
+
+
+@pytest.mark.models
+def test_merge_children(EN):
+    """Test that attachments work correctly after merging."""
+    doc = EN('WKRO played songs by the beach boys all night')
+    # merge 'The Beach Boys'
+    doc.merge(doc[4].idx, doc[6].idx + len(doc[6]), 'NAMED', 'LEMMA', 'TYPE')
+
+    for word in doc:
+        if word.i < word.head.i:
+            assert word in list(word.head.lefts)
+        elif word.i > word.head.i:
+            assert word in list(word.head.rights)

View File

@@ -1,8 +1,11 @@
 #!/usr/bin/env python
-import sys
+from __future__ import unicode_literals
 import re
 import os
 import ast
+import io
+import plac

 # cgi.escape is deprecated since py32
 try:

@@ -11,55 +14,62 @@ except ImportError:
     from cgi import escape

-src_dirname = sys.argv[1]
-dst_dirname = sys.argv[2]
-prefix = "test_"
-
-for filename in os.listdir(src_dirname):
-    match = re.match(re.escape(prefix) + r"(.+)\.py$", filename)
-    if not match:
-        continue
-
-    name = match.group(1)
-    source = open(os.path.join(src_dirname, filename)).readlines()
-    tree = ast.parse("".join(source))
-
-    for root in tree.body:
-        if isinstance(root, ast.FunctionDef) and root.name.startswith(prefix):
-
-            # only ast.expr and ast.stmt have line numbers, see:
-            # https://docs.python.org/2/library/ast.html#ast.AST.lineno
-            line_numbers = []
-
-            for node in ast.walk(root):
-                if hasattr(node, "lineno"):
-                    line_numbers.append(node.lineno)
-
-            body = source[min(line_numbers)-1:max(line_numbers)]
-            while not body[0][0].isspace():
-                body = body[1:]
-
-            # make sure we are inside an indented function body
-            assert all([l[0].isspace() for l in body])
-
-            offset = 0
-            for line in body:
-                match = re.search(r"[^\s]", line)
-                if match:
-                    offset = match.start(0)
-                    break
-
-            # remove indentation
-            assert offset > 0
-
-            for i in range(len(body)):
-                body[i] = body[i][offset:] if len(body[i]) > offset else "\n"
-
-            # make sure empty lines contain a newline
-            assert all([l[-1] == "\n" for l in body])
-
-            code_filename = "%s.%s" % (name, root.name[len(prefix):])
-
-            with open(os.path.join(dst_dirname, code_filename), "w") as f:
-                f.write(escape("".join(body)))
+
+# e.g. python website/create_code_samples tests/website/ website/src/
+def main(src_dirname, dst_dirname):
+    prefix = "test_"
+
+    for filename in os.listdir(src_dirname):
+        if not filename.startswith('test_'):
+            continue
+        if not filename.endswith('.py'):
+            continue
+
+        # Remove test_ prefix and .py suffix
+        name = filename[6:-3]
+        with io.open(os.path.join(src_dirname, filename), 'r', encoding='utf8') as file_:
+            source = file_.readlines()
+        tree = ast.parse("".join(source))
+
+        for root in tree.body:
+            if isinstance(root, ast.FunctionDef) and root.name.startswith(prefix):
+
+                # only ast.expr and ast.stmt have line numbers, see:
+                # https://docs.python.org/2/library/ast.html#ast.AST.lineno
+                line_numbers = []
+
+                for node in ast.walk(root):
+                    if hasattr(node, "lineno"):
+                        line_numbers.append(node.lineno)
+
+                body = source[min(line_numbers)-1:max(line_numbers)]
+                while not body[0][0].isspace():
+                    body = body[1:]
+
+                # make sure we are inside an indented function body
+                assert all([l[0].isspace() for l in body])
+
+                offset = 0
+                for line in body:
+                    match = re.search(r"[^\s]", line)
+                    if match:
+                        offset = match.start(0)
+                        break
+
+                # remove indentation
+                assert offset > 0
+
+                for i in range(len(body)):
+                    body[i] = body[i][offset:] if len(body[i]) > offset else "\n"
+
+                # make sure empty lines contain a newline
+                assert all([l[-1] == "\n" for l in body])
+
+                code_filename = "%s.%s" % (name, root.name[len(prefix):])
+
+                with io.open(os.path.join(dst_dirname, code_filename),
+                             "w", encoding='utf8') as f:
+                    f.write(escape("".join(body)))
+
+
+if __name__ == '__main__':
+    plac.call(main)