Merge remote-tracking branch 'refs/remotes/honnibal/master'

This commit is contained in:
maxirmx 2015-10-20 23:27:20 +03:00
commit 14b89ff1c5
17 changed files with 779 additions and 93 deletions

View File

@ -11,6 +11,7 @@ import ujson
import codecs
from preshed.counter import PreshCounter
from joblib import Parallel, delayed
import io
from spacy.en import English
from spacy.strings import StringStore

273
examples/nn_text_class.py Normal file
View File

@ -0,0 +1,273 @@
from __future__ import unicode_literals
from __future__ import print_function
from __future__ import division
from collections import defaultdict
from pathlib import Path
import numpy
import plac
import spacy.en
def read_data(nlp, data_dir):
for subdir, label in (('pos', 1), ('neg', 0)):
for filename in (data_dir / subdir).iterdir():
text = filename.open().read()
doc = nlp(text)
if len(doc) >= 1:
yield doc, label
def partition(examples, split_size):
examples = list(examples)
numpy.random.shuffle(examples)
n_docs = len(examples)
split = int(n_docs * split_size)
return examples[:split], examples[split:]
def minibatch(data, bs=24):
for i in range(0, len(data), bs):
yield data[i:i+bs]
class Extractor(object):
def __init__(self, nlp, vector_length, dropout=0.3):
self.nlp = nlp
self.dropout = dropout
self.vector = numpy.zeros((vector_length, ))
def doc2bow(self, doc, dropout=None):
if dropout is None:
dropout = self.dropout
bow = defaultdict(int)
all_words = defaultdict(int)
for word in doc:
if numpy.random.random() >= dropout and not word.is_punct:
bow[word.lower] += 1
all_words[word.lower] += 1
if sum(bow.values()) >= 1:
return bow
else:
return all_words
def bow2vec(self, bow, E):
self.vector.fill(0)
n = 0
for orth_id, freq in bow.items():
self.vector += self.nlp.vocab[self.nlp.vocab.strings[orth_id]].repvec * freq
# Apply the fine-tuning we've learned
if orth_id < E.shape[0]:
self.vector += E[orth_id] * freq
n += freq
return self.vector / n
class NeuralNetwork(object):
def __init__(self, depth, width, n_classes, n_vocab, extracter, optimizer):
self.depth = depth
self.width = width
self.n_classes = n_classes
self.weights = Params.random(depth, width, width, n_classes, n_vocab)
self.doc2bow = extracter.doc2bow
self.bow2vec = extracter.bow2vec
self.optimizer = optimizer
self._gradient = Params.zero(depth, width, width, n_classes, n_vocab)
self._activity = numpy.zeros((depth, width))
def train(self, batch):
activity = self._activity
gradient = self._gradient
activity.fill(0)
gradient.data.fill(0)
loss = 0
word_freqs = defaultdict(int)
for doc, label in batch:
word_ids = self.doc2bow(doc)
vector = self.bow2vec(word_ids, self.weights.E)
self.forward(activity, vector)
loss += self.backprop(vector, gradient, activity, word_ids, label)
for w, freq in word_ids.items():
word_freqs[w] += freq
self.optimizer(self.weights, gradient, len(batch), word_freqs)
return loss
def predict(self, doc):
actv = self._activity
actv.fill(0)
W = self.weights.W
b = self.weights.b
E = self.weights.E
vector = self.bow2vec(self.doc2bow(doc, dropout=0.0), E)
self.forward(actv, vector)
return numpy.argmax(softmax(actv[-1], W[-1], b[-1]))
def forward(self, actv, in_):
actv.fill(0)
W = self.weights.W; b = self.weights.b
actv[0] = relu(in_, W[0], b[0])
for i in range(1, self.depth):
actv[i] = relu(actv[i-1], W[i], b[i])
def backprop(self, input_vector, gradient, activity, ids, label):
W = self.weights.W
b = self.weights.b
target = numpy.zeros(self.n_classes)
target[label] = 1.0
pred = softmax(activity[-1], W[-1], b[-1])
delta = pred - target
for i in range(self.depth, 0, -1):
gradient.b[i] += delta
gradient.W[i] += numpy.outer(delta, activity[i-1])
delta = d_relu(activity[i-1]) * W[i].T.dot(delta)
gradient.b[0] += delta
gradient.W[0] += numpy.outer(delta, input_vector)
tuning = W[0].T.dot(delta).reshape((self.width,)) / len(ids)
for w, freq in ids.items():
if w < gradient.E.shape[0]:
gradient.E[w] += tuning * freq
return -sum(target * numpy.log(pred))
def softmax(actvn, W, b):
w = W.dot(actvn) + b
ew = numpy.exp(w - max(w))
return (ew / sum(ew)).ravel()
def relu(actvn, W, b):
x = W.dot(actvn) + b
return x * (x > 0)
def d_relu(x):
return x > 0
class Adagrad(object):
def __init__(self, lr, rho):
self.eps = 1e-3
# initial learning rate
self.learning_rate = lr
self.rho = rho
# stores sum of squared gradients
#self.h = numpy.zeros(self.dim)
#self._curr_rate = numpy.zeros(self.h.shape)
self.h = None
self._curr_rate = None
def __call__(self, weights, gradient, batch_size, word_freqs):
if self.h is None:
self.h = numpy.zeros(gradient.data.shape)
self._curr_rate = numpy.zeros(gradient.data.shape)
self.L2_penalty(gradient, weights, word_freqs)
update = self.rescale(gradient.data / batch_size)
weights.data -= update
def rescale(self, gradient):
if self.h is None:
self.h = numpy.zeros(gradient.data.shape)
self._curr_rate = numpy.zeros(gradient.data.shape)
self._curr_rate.fill(0)
self.h += gradient ** 2
self._curr_rate = self.learning_rate / (numpy.sqrt(self.h) + self.eps)
return self._curr_rate * gradient
def L2_penalty(self, gradient, weights, word_freqs):
# L2 Regularization
for i in range(len(weights.W)):
gradient.W[i] += weights.W[i] * self.rho
gradient.b[i] += weights.b[i] * self.rho
for w, freq in word_freqs.items():
if w < gradient.E.shape[0]:
gradient.E[w] += weights.E[w] * self.rho
class Params(object):
@classmethod
def zero(cls, depth, n_embed, n_hidden, n_labels, n_vocab):
return cls(depth, n_embed, n_hidden, n_labels, n_vocab, lambda x: numpy.zeros((x,)))
@classmethod
def random(cls, depth, nE, nH, nL, nV):
return cls(depth, nE, nH, nL, nV, lambda x: (numpy.random.rand(x) * 2 - 1) * 0.08)
def __init__(self, depth, n_embed, n_hidden, n_labels, n_vocab, initializer):
nE = n_embed; nH = n_hidden; nL = n_labels; nV = n_vocab
n_weights = sum([
(nE * nH) + nH,
(nH * nH + nH) * depth,
(nH * nL) + nL,
(nV * nE)
])
self.data = initializer(n_weights)
self.W = []
self.b = []
i = self._add_layer(0, nE, nH)
for _ in range(1, depth):
i = self._add_layer(i, nH, nH)
i = self._add_layer(i, nL, nH)
self.E = self.data[i : i + (nV * nE)].reshape((nV, nE))
self.E.fill(0)
def _add_layer(self, start, x, y):
end = start + (x * y)
self.W.append(self.data[start : end].reshape((x, y)))
self.b.append(self.data[end : end + x].reshape((x, )))
return end + x
@plac.annotations(
data_dir=("Data directory", "positional", None, Path),
n_iter=("Number of iterations (epochs)", "option", "i", int),
width=("Size of hidden layers", "option", "H", int),
depth=("Depth", "option", "d", int),
dropout=("Drop-out rate", "option", "r", float),
rho=("Regularization penalty", "option", "p", float),
eta=("Learning rate", "option", "e", float),
batch_size=("Batch size", "option", "b", int),
vocab_size=("Number of words to fine-tune", "option", "w", int),
)
def main(data_dir, depth=3, width=300, n_iter=5, vocab_size=40000,
batch_size=24, dropout=0.3, rho=1e-5, eta=0.005):
n_classes = 2
print("Loading")
nlp = spacy.en.English(parser=False)
train_data, dev_data = partition(read_data(nlp, data_dir / 'train'), 0.8)
print("Begin training")
extracter = Extractor(nlp, width, dropout=0.3)
optimizer = Adagrad(eta, rho)
model = NeuralNetwork(depth, width, n_classes, vocab_size, extracter, optimizer)
prev_best = 0
best_weights = None
for epoch in range(n_iter):
numpy.random.shuffle(train_data)
train_loss = 0.0
for batch in minibatch(train_data, bs=batch_size):
train_loss += model.train(batch)
n_correct = sum(model.predict(x) == y for x, y in dev_data)
print(epoch, train_loss, n_correct / len(dev_data))
if n_correct >= prev_best:
best_weights = model.weights.data.copy()
prev_best = n_correct
model.weights.data = best_weights
print("Evaluating")
eval_data = list(read_data(nlp, data_dir / 'test'))
n_correct = sum(model.predict(x) == y for x, y in eval_data)
print(n_correct / len(eval_data))
if __name__ == '__main__':
#import cProfile
#import pstats
#cProfile.runctx("main(Path('data/aclImdb'))", globals(), locals(), "Profile.prof")
#s = pstats.Stats("Profile.prof")
#s.strip_dirs().sort_stats("time").print_stats(100)
plac.call(main)

2
fabfile.py vendored
View File

@ -48,7 +48,7 @@ def prebuild(build_dir='/tmp/build_spacy'):
local('virtualenv ' + build_venv)
with prefix('cd %s && PYTHONPATH=`pwd` && . %s/bin/activate' % (build_dir, build_venv)):
local('pip install cython fabric fabtools pytest')
local('pip install -r requirements.txt')
local('pip install --no-cache-dir -r requirements.txt')
local('fab clean make')
local('cp -r %s/corpora/en/wordnet corpora/en/' % spacy_dir)
local('cp %s/corpora/en/freqs.txt.gz corpora/en/' % spacy_dir)

View File

@ -342,7 +342,7 @@ hardcoded_specials = {
"\n": [{"F": "\n", "pos": "SP"}],
"\t": [{"F": "\t", "pos": "SP"}],
" ": [{"F": " ", "pos": "SP"}],
u"\xa0": [{"F": u"\xa0", "pos": "SP", "L": " "}]
u"\u00a0": [{"F": u"\u00a0", "pos": "SP", "L": " "}]
}

View File

@ -1,3 +1,4 @@
\.\.\.
(?<=[a-z])\.(?=[A-Z])
(?<=[a-zA-Z])-(?=[a-zA-z])
(?<=[0-9])-(?=[0-9])

View File

@ -6,7 +6,6 @@ thinc == 3.3
murmurhash == 0.24
text-unidecode
numpy
wget
plac
six
ujson

View File

@ -162,7 +162,7 @@ def run_setup(exts):
ext_modules=exts,
license="MIT",
install_requires=['numpy', 'murmurhash', 'cymem >= 1.11', 'preshed >= 0.42',
'thinc == 3.3', "text_unidecode", 'wget', 'plac', 'six',
'thinc == 3.3', "text_unidecode", 'plac', 'six',
'ujson', 'cloudpickle'],
setup_requires=["headers_workaround"],
cmdclass = {'build_ext': build_ext_subclass },
@ -175,13 +175,14 @@ def run_setup(exts):
headers_workaround.install_headers('numpy')
VERSION = '0.95'
VERSION = '0.96'
def main(modules, is_pypy):
language = "cpp"
includes = ['.', path.join(sys.prefix, 'include')]
if sys.platform.startswith('darwin'):
compile_options['other'].append(['-mmacosx-version-min=10.8', '-stdlib=libc++'])
link_opions['other'].append('-lc++')
compile_options['other'].append('-mmacosx-version-min=10.8')
compile_options['other'].append('-stdlib=libc++')
link_options['other'].append('-lc++')
if use_cython:
cython_setup(modules, language, includes)
else:

View File

@ -1,11 +1,13 @@
from __future__ import print_function
from os import path
import sys
import os
import tarfile
import shutil
import wget
import plac
from . import uget
# TODO: Read this from the same source as the setup
VERSION = '0.9.5'
@ -13,39 +15,45 @@ AWS_STORE = 'https://s3-us-west-1.amazonaws.com/media.spacynlp.com'
ALL_DATA_DIR_URL = '%s/en_data_all-%s.tgz' % (AWS_STORE, VERSION)
DEST_DIR = path.join(path.dirname(__file__), 'data')
DEST_DIR = path.join(path.dirname(path.abspath(__file__)), 'data')
def download_file(url, out):
wget.download(url, out=out)
return url.rsplit('/', 1)[1]
def download_file(url, dest_dir):
return uget.download(url, dest_dir, console=sys.stdout)
def install_data(url, dest_dir):
filename = download_file(url, dest_dir)
t = tarfile.open(path.join(dest_dir, filename))
t = tarfile.open(filename)
t.extractall(dest_dir)
def install_parser_model(url, dest_dir):
filename = download_file(url, dest_dir)
t = tarfile.open(path.join(dest_dir, filename), mode=":gz")
t.extractall(path.dirname(__file__))
t = tarfile.open(filename, mode=":gz")
t.extractall(dest_dir)
def install_dep_vectors(url, dest_dir):
if not os.path.exists(dest_dir):
os.mkdir(dest_dir)
filename = download_file(url, dest_dir)
download_file(url, dest_dir)
def main(data_size='all'):
@plac.annotations(
force=("Force overwrite", "flag", "f", bool),
)
def main(data_size='all', force=False):
if data_size == 'all':
data_url = ALL_DATA_DIR_URL
elif data_size == 'small':
data_url = SM_DATA_DIR_URL
if path.exists(DEST_DIR):
if force and path.exists(DEST_DIR):
shutil.rmtree(DEST_DIR)
install_data(data_url, path.dirname(DEST_DIR))
if not os.path.exists(DEST_DIR):
os.makedirs(DEST_DIR)
install_data(data_url, DEST_DIR)
if __name__ == '__main__':

246
spacy/en/uget.py Normal file
View File

@ -0,0 +1,246 @@
import os
import time
import io
import math
import re
try:
from urllib.parse import urlparse
from urllib.request import urlopen, Request
from urllib.error import HTTPError
except ImportError:
from urllib2 import urlopen, urlparse, Request, HTTPError
class UnknownContentLengthException(Exception): pass
class InvalidChecksumException(Exception): pass
class UnsupportedHTTPCodeException(Exception): pass
class InvalidOffsetException(Exception): pass
class MissingChecksumHeader(Exception): pass
CHUNK_SIZE = 16 * 1024
class RateSampler(object):
def __init__(self, period=1):
self.rate = None
self.reset = True
self.period = period
def __enter__(self):
if self.reset:
self.reset = False
self.start = time.time()
self.counter = 0
def __exit__(self, type, value, traceback):
elapsed = time.time() - self.start
if elapsed >= self.period:
self.reset = True
self.rate = float(self.counter) / elapsed
def update(self, value):
self.counter += value
def format(self, unit="MB"):
if self.rate is None:
return None
divisor = {'MB': 1048576, 'kB': 1024}
return "%0.2f%s/s" % (self.rate / divisor[unit], unit)
class TimeEstimator(object):
def __init__(self, cooldown=1):
self.cooldown = cooldown
self.start = time.time()
self.time_left = None
def update(self, bytes_read, total_size):
elapsed = time.time() - self.start
if elapsed > self.cooldown:
self.time_left = math.ceil(elapsed * total_size /
bytes_read - elapsed)
def format(self):
if self.time_left is None:
return None
res = "eta "
if self.time_left / 60 >= 1:
res += "%dm " % (self.time_left / 60)
return res + "%ds" % (self.time_left % 60)
def format_bytes_read(bytes_read, unit="MB"):
divisor = {'MB': 1048576, 'kB': 1024}
return "%0.2f%s" % (float(bytes_read) / divisor[unit], unit)
def format_percent(bytes_read, total_size):
percent = round(bytes_read * 100.0 / total_size, 2)
return "%0.2f%%" % percent
def get_content_range(response):
content_range = response.headers.get('Content-Range', "").strip()
if content_range:
m = re.match(r"bytes (\d+)-(\d+)/(\d+)", content_range)
if m:
return [int(v) for v in m.groups()]
def get_content_length(response):
if 'Content-Length' not in response.headers:
raise UnknownContentLengthException
return int(response.headers.get('Content-Length').strip())
def get_url_meta(url, checksum_header=None):
class HeadRequest(Request):
def get_method(self):
return "HEAD"
r = urlopen(HeadRequest(url))
res = {'size': get_content_length(r)}
if checksum_header:
value = r.headers.get(checksum_header)
if value:
res['checksum'] = value
r.close()
return res
def progress(console, bytes_read, total_size, transfer_rate, eta):
fields = [
format_bytes_read(bytes_read),
format_percent(bytes_read, total_size),
transfer_rate.format(),
eta.format(),
" " * 10,
]
console.write("Downloaded %s\r" % " ".join(filter(None, fields)))
console.flush()
def read_request(request, offset=0, console=None,
progress_func=None, write_func=None):
# support partial downloads
if offset > 0:
request.add_header('Range', "bytes=%s-" % offset)
try:
response = urlopen(request)
except HTTPError as e:
if e.code == 416: # Requested Range Not Satisfiable
raise InvalidOffsetException
# TODO add http error handling here
raise UnsupportedHTTPCodeException(e.code)
total_size = get_content_length(response) + offset
bytes_read = offset
# sanity checks
if response.code == 200: # OK
assert offset == 0
elif response.code == 206: # Partial content
range_start, range_end, range_total = get_content_range(response)
assert range_start == offset
assert range_total == total_size
assert range_end + 1 - range_start == total_size - bytes_read
else:
raise UnsupportedHTTPCodeException(response.code)
eta = TimeEstimator()
transfer_rate = RateSampler()
if console:
if offset > 0:
console.write("Continue downloading...\n")
else:
console.write("Downloading...\n")
while True:
with transfer_rate:
chunk = response.read(CHUNK_SIZE)
if not chunk:
if progress_func and console:
console.write('\n')
break
bytes_read += len(chunk)
transfer_rate.update(len(chunk))
eta.update(bytes_read - offset, total_size - offset)
if progress_func and console:
progress_func(console, bytes_read, total_size, transfer_rate, eta)
if write_func:
write_func(chunk)
response.close()
assert bytes_read == total_size
return response
def download(url, path=".",
checksum=None, checksum_header=None,
headers=None, console=None):
if os.path.isdir(path):
path = os.path.join(path, url.rsplit('/', 1)[1])
path = os.path.abspath(path)
with io.open(path, "a+b") as f:
size = f.tell()
# update checksum of partially downloaded file
if checksum:
f.seek(0, os.SEEK_SET)
for chunk in iter(lambda: f.read(CHUNK_SIZE), b""):
checksum.update(chunk)
def write(chunk):
if checksum:
checksum.update(chunk)
f.write(chunk)
request = Request(url)
# request headers
if headers:
for key, value in headers.items():
request.add_header(key, value)
try:
response = read_request(request,
offset=size,
console=console,
progress_func=progress,
write_func=write)
except InvalidOffsetException:
response = None
if checksum:
if response:
origin_checksum = response.headers.get(checksum_header)
else:
# check whether file is already complete
meta = get_url_meta(url, checksum_header)
origin_checksum = meta.get('checksum')
if origin_checksum is None:
raise MissingChecksumHeader
if checksum.hexdigest() != origin_checksum:
raise InvalidChecksumException
if console:
console.write("checksum/sha256 OK\n")
return path

View File

@ -20,8 +20,6 @@ from .tokens.doc cimport get_token_attr
from .tokens.doc cimport Doc
from .vocab cimport Vocab
from libcpp.vector cimport vector
from .attrs import FLAG61 as U_ENT
from .attrs import FLAG60 as B2_ENT
@ -221,8 +219,7 @@ cdef class Matcher:
q = 0
# Go over the open matches, extending or finalizing if able. Otherwise,
# we over-write them (q doesn't advance)
for i in range(partials.size()):
state = partials.at(i)
for state in partials:
if match(state, token):
if is_final(state):
label, start, end = get_entity(state, token, token_i)
@ -233,8 +230,7 @@ cdef class Matcher:
q += 1
partials.resize(q)
# Check whether we open any new patterns on this token
for i in range(self.n_patterns):
state = self.patterns[i]
for state in self.patterns:
if match(state, token):
if is_final(state):
label, start, end = get_entity(state, token, token_i)
@ -242,7 +238,16 @@ cdef class Matcher:
matches.append((label, start, end))
else:
partials.push_back(state + 1)
doc.ents = [(e.label, e.start, e.end) for e in doc.ents] + matches
seen = set()
filtered = []
for label, start, end in sorted(matches, key=lambda m: (m[1], -(m[1] - m[2]))):
if all(i in seen for i in range(start, end)):
continue
else:
for i in range(start, end):
seen.add(i)
filtered.append((label, start, end))
doc.ents = [(e.label, e.start, e.end) for e in doc.ents] + filtered
return matches

View File

@ -72,6 +72,10 @@ cdef class Tokenizer:
Returns:
tokens (Doc): A Doc object, giving access to a sequence of LexemeCs.
"""
if len(string) >= (2 ** 30):
raise ValueError(
"String is too long: %d characters. Max is 2**30." % len(string)
)
cdef int length = len(string)
cdef Doc tokens = Doc(self.vocab)
if length == 0:

View File

@ -447,9 +447,9 @@ cdef class Doc:
cdef Span span = self[start:end]
# Get LexemeC for newly merged token
new_orth = ''.join([t.string for t in span])
new_orth = ''.join([t.text_with_ws for t in span])
if span[-1].whitespace_:
new_orth = new_orth[:-1]
new_orth = new_orth[:-len(span[-1].whitespace_)]
cdef const LexemeC* lex = self.vocab.get(self.mem, new_orth)
# House the new merged token where it starts
cdef TokenC* token = &self.data[start]
@ -508,16 +508,26 @@ cdef int set_children_from_heads(TokenC* tokens, int length) except -1:
cdef TokenC* head
cdef TokenC* child
cdef int i
# Set number of left/right children to 0. We'll increment it in the loops.
for i in range(length):
tokens[i].l_kids = 0
tokens[i].r_kids = 0
tokens[i].l_edge = i
tokens[i].r_edge = i
# Set left edges
for i in range(length):
child = &tokens[i]
head = &tokens[i + child.head]
if child < head and child.l_edge < head.l_edge:
if child < head:
if child.l_edge < head.l_edge:
head.l_edge = child.l_edge
head.l_kids += 1
# Set right edges --- same as above, but iterate in reverse
for i in range(length-1, -1, -1):
child = &tokens[i]
head = &tokens[i + child.head]
if child > head and child.r_edge > head.r_edge:
if child > head:
if child.r_edge > head.r_edge:
head.r_edge = child.r_edge
head.r_kids += 1

View File

@ -278,7 +278,7 @@ cdef class Token:
property whitespace_:
def __get__(self):
return self.string[self.c.lex.length:]
return ' ' if self.c.spacy else ''
property orth_:
def __get__(self):

View File

@ -1,17 +1,102 @@
import pytest
from spacy.matcher import Matcher
from spacy.attrs import LOWER
@pytest.mark.xfail
def test_overlap_issue118(EN):
'''Test a bug that arose from having overlapping matches'''
doc = EN.tokenizer(u'how many points did lebron james score against the boston celtics last night')
ORG = doc.vocab.strings['ORG']
matcher = Matcher(EN.vocab, {'BostonCeltics': ('ORG', {}, [[{'lower': 'boston'}, {'lower': 'celtics'}], [{'lower': 'celtics'}]])})
matcher = Matcher(EN.vocab,
{'BostonCeltics':
('ORG', {},
[
[{LOWER: 'celtics'}],
[{LOWER: 'boston'}, {LOWER: 'celtics'}],
]
)
}
)
assert len(list(doc.ents)) == 0
matches = matcher(doc)
assert matches == [(ORG, 9, 11)]
assert matches == [(ORG, 9, 11), (ORG, 10, 11)]
ents = list(doc.ents)
assert len(ents) == 1
assert ents[0].label == ORG
assert ents[0].start == 9
assert ents[0].end == 11
def test_overlap_reorder(EN):
'''Test order dependence'''
doc = EN.tokenizer(u'how many points did lebron james score against the boston celtics last night')
ORG = doc.vocab.strings['ORG']
matcher = Matcher(EN.vocab,
{'BostonCeltics':
('ORG', {},
[
[{LOWER: 'boston'}, {LOWER: 'celtics'}],
[{LOWER: 'celtics'}],
]
)
}
)
assert len(list(doc.ents)) == 0
matches = matcher(doc)
assert matches == [(ORG, 9, 11), (ORG, 10, 11)]
ents = list(doc.ents)
assert len(ents) == 1
assert ents[0].label == ORG
assert ents[0].start == 9
assert ents[0].end == 11
def test_overlap_prefix(EN):
'''Test order dependence'''
doc = EN.tokenizer(u'how many points did lebron james score against the boston celtics last night')
ORG = doc.vocab.strings['ORG']
matcher = Matcher(EN.vocab,
{'BostonCeltics':
('ORG', {},
[
[{LOWER: 'boston'}],
[{LOWER: 'boston'}, {LOWER: 'celtics'}],
]
)
}
)
assert len(list(doc.ents)) == 0
matches = matcher(doc)
assert matches == [(ORG, 9, 10), (ORG, 9, 11)]
ents = list(doc.ents)
assert len(ents) == 1
assert ents[0].label == ORG
assert ents[0].start == 9
assert ents[0].end == 11
def test_overlap_prefix_reorder(EN):
'''Test order dependence'''
doc = EN.tokenizer(u'how many points did lebron james score against the boston celtics last night')
ORG = doc.vocab.strings['ORG']
matcher = Matcher(EN.vocab,
{'BostonCeltics':
('ORG', {},
[
[{LOWER: 'boston'}, {LOWER: 'celtics'}],
[{LOWER: 'boston'}],
]
)
}
)
assert len(list(doc.ents)) == 0
matches = matcher(doc)
assert matches == [(ORG, 9, 10), (ORG, 9, 11)]
ents = list(doc.ents)
assert len(ents) == 1
assert ents[0].label == ORG

View File

@ -7,6 +7,10 @@ def test_hyphen(en_tokenizer):
assert len(tokens) == 3
def test_numeric_range(en_tokenizer):
tokens = en_tokenizer('0.1-13.5')
assert len(tokens) == 3
def test_period(en_tokenizer):
tokens = en_tokenizer('best.Known')
assert len(tokens) == 3

View File

@ -109,3 +109,42 @@ def test_set_ents(EN):
assert ent.label_ == 'PRODUCT'
assert ent.start == 2
assert ent.end == 4
def test_merge(EN):
doc = EN('WKRO played songs by the beach boys all night')
assert len(doc) == 9
# merge 'The Beach Boys'
doc.merge(doc[4].idx, doc[6].idx + len(doc[6]), 'NAMED', 'LEMMA', 'TYPE')
assert len(doc) == 7
assert doc[4].text == 'the beach boys'
assert doc[4].text_with_ws == 'the beach boys '
assert doc[4].tag_ == 'NAMED'
def test_merge_end_string(EN):
doc = EN('WKRO played songs by the beach boys all night')
assert len(doc) == 9
# merge 'The Beach Boys'
doc.merge(doc[7].idx, doc[8].idx + len(doc[8]), 'NAMED', 'LEMMA', 'TYPE')
assert len(doc) == 8
assert doc[7].text == 'all night'
assert doc[7].text_with_ws == 'all night'
@pytest.mark.models
def test_merge_children(EN):
"""Test that attachments work correctly after merging."""
doc = EN('WKRO played songs by the beach boys all night')
# merge 'The Beach Boys'
doc.merge(doc[4].idx, doc[6].idx + len(doc[6]), 'NAMED', 'LEMMA', 'TYPE')
for word in doc:
if word.i < word.head.i:
assert word in list(word.head.lefts)
elif word.i > word.head.i:
assert word in list(word.head.rights)

View File

@ -1,8 +1,11 @@
#!/usr/bin/env python
import sys
import re
from __future__ import unicode_literals
import os
import ast
import io
import plac
# cgi.escape is deprecated since py32
try:
@ -11,18 +14,20 @@ except ImportError:
from cgi import escape
src_dirname = sys.argv[1]
dst_dirname = sys.argv[2]
# e.g. python website/create_code_samples tests/website/ website/src/
def main(src_dirname, dst_dirname):
prefix = "test_"
for filename in os.listdir(src_dirname):
match = re.match(re.escape(prefix) + r"(.+)\.py$", filename)
if not match:
if not filename.startswith('test_'):
continue
if not filename.endswith('.py'):
continue
name = match.group(1)
source = open(os.path.join(src_dirname, filename)).readlines()
# Remove test_ prefix and .py suffix
name = filename[6:-3]
with io.open(os.path.join(src_dirname, filename), 'r', encoding='utf8') as file_:
source = file_.readlines()
tree = ast.parse("".join(source))
for root in tree.body:
@ -61,5 +66,10 @@ for filename in os.listdir(src_dirname):
code_filename = "%s.%s" % (name, root.name[len(prefix):])
with open(os.path.join(dst_dirname, code_filename), "w") as f:
with io.open(os.path.join(dst_dirname, code_filename),
"w", encoding='utf8') as f:
f.write(escape("".join(body)))
if __name__ == '__main__':
plac.call(main)