pull changes

Andreas Grivas 2015-11-04 19:56:27 +02:00
commit cb376eadd3
24 changed files with 159 additions and 143 deletions

fabfile.py

@@ -54,10 +54,10 @@ def prebuild(build_dir='/tmp/build_spacy'):
 local('pip install --no-cache-dir -r requirements.txt')
 local('fab clean make')
 local('cp -r %s/corpora/en/wordnet corpora/en/' % spacy_dir)
+local('cp %s/corpora/en/freqs.txt.gz corpora/en/' % spacy_dir)
 local('PYTHONPATH=`pwd` python bin/init_model.py en lang_data corpora spacy/en/data')
 local('fab test')
-local('python setup.py sdist')
+local('PYTHONPATH=`pwd` python -m spacy.en.download --force all')
+local('py.test --models spacy/tests/')
 def docs():
@@ -121,9 +121,8 @@ def clean():
 def test():
 with virtualenv(VENV_DIR):
-# Run each test file separately. pytest is performing poorly, not sure why
 with lcd(path.dirname(__file__)):
-local('py.test -x tests/')
+local('py.test -x spacy/tests')
 def train(json_dir=None, dev_loc=None, model_dir=None):


@@ -13,6 +13,21 @@ from distutils.command.build_ext import build_ext
 import platform
+PACKAGE_DATA = {
+"spacy": ["*.pxd"],
+"spacy.tokens": ["*.pxd"],
+"spacy.serialize": ["*.pxd"],
+"spacy.syntax": ["*.pxd"],
+"spacy.en": [
+"*.pxd",
+"data/wordnet/*.exc",
+"data/wordnet/index.*",
+"data/tokenizer/*",
+"data/vocab/serializer.json"
+]
+}
 # By subclassing build_extensions we have the actual compiler that will be used which is really known only after finalize_options
 # http://stackoverflow.com/questions/724664/python-distutils-how-to-get-a-compiler-that-is-going-to-be-used
 compile_options = {'msvc' : ['/Ox', '/EHsc'] ,
@@ -81,6 +96,8 @@ except OSError:
 pass
 def clean(mod_names):
 for name in mod_names:
 name = name.replace('.', '/')
@@ -128,15 +145,7 @@ def cython_setup(mod_names, language, includes):
 author_email='honnibal@gmail.com',
 version=VERSION,
 url="http://honnibal.github.io/spaCy/",
-package_data={"spacy": ["*.pxd", "tests/*.py", "tests/*/*.py"],
-"spacy.tokens": ["*.pxd"],
-"spacy.serialize": ["*.pxd"],
-"spacy.en": ["*.pxd", "data/pos/*",
-"data/wordnet/*", "data/tokenizer/*",
-"data/vocab/tag_map.json",
-"data/vocab/lexemes.bin",
-"data/vocab/strings.json"],
-"spacy.syntax": ["*.pxd"]},
+package_data=PACKAGE_DATA,
 ext_modules=exts,
 cmdclass={'build_ext': build_ext_cython_subclass},
 license="MIT",
@@ -154,7 +163,7 @@ def run_setup(exts):
 'spacy.tests.munge',
 'spacy.tests.parser',
 'spacy.tests.serialize',
-'spacy.tests.spans',
+'spacy.tests.span',
 'spacy.tests.tagger',
 'spacy.tests.tokenizer',
 'spacy.tests.tokens',
@@ -165,18 +174,11 @@ def run_setup(exts):
 author_email='honnibal@gmail.com',
 version=VERSION,
 url="http://honnibal.github.io/spaCy/",
-package_data={"spacy": ["*.pxd"],
-"spacy.en": ["*.pxd", "data/pos/*",
-"data/wordnet/*", "data/tokenizer/*",
-"data/vocab/lexemes.bin",
-"data/vocab/serializer.json",
-"data/vocab/oov_prob",
-"data/vocab/strings.txt"],
-"spacy.syntax": ["*.pxd"]},
+package_data=PACKAGE_DATA,
 ext_modules=exts,
 license="MIT",
-install_requires=['numpy', 'murmurhash', 'cymem >= 1.30', 'preshed >= 0.43',
-'thinc >= 3.4.1', "text_unidecode", 'plac', 'six',
+install_requires=['numpy', 'murmurhash', 'cymem == 1.30', 'preshed == 0.43',
+'thinc == 3.4.1', "text_unidecode", 'plac', 'six',
 'ujson', 'cloudpickle'],
 setup_requires=["headers_workaround"],
 cmdclass = {'build_ext': build_ext_subclass },
@@ -189,7 +191,7 @@ def run_setup(exts):
 headers_workaround.install_headers('numpy')
-VERSION = '0.97'
+VERSION = '0.99'
 def main(modules, is_pypy):
 language = "cpp"
 includes = ['.', path.join(sys.prefix, 'include')]
@@ -215,7 +217,7 @@ MOD_NAMES = ['spacy.parts_of_speech', 'spacy.strings',
 'spacy.syntax.arc_eager',
 'spacy.syntax._parse_features',
 'spacy.gold', 'spacy.orth',
-'spacy.tokens.doc', 'spacy.tokens.spans', 'spacy.tokens.token',
+'spacy.tokens.doc', 'spacy.tokens.span', 'spacy.tokens.token',
 'spacy.serialize.packer', 'spacy.serialize.huffman', 'spacy.serialize.bits',
 'spacy.cfile', 'spacy.matcher',
 'spacy.syntax.ner',


@@ -38,6 +38,7 @@ def install_data(url, extract_path, download_path):
 assert tmp == download_path
 t = tarfile.open(download_path)
 t.extractall(extract_path)
+os.unlink(download_path)
 @plac.annotations(
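Note: with the added os.unlink the downloader now deletes the fetched archive once it has been extracted, so only the unpacked data stays on disk. A minimal usage sketch, using only the invocation this commit itself references (in the fabfile change and in the token.pyx error message); treat it as illustrative rather than authoritative:

    # download (or forcibly re-download) the English model data
    python -m spacy.en.download --force all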


@@ -88,7 +88,7 @@ class Language(object):
 return orth.like_url(string)
 @staticmethod
-def like_number(string):
+def like_num(string):
 return orth.like_number(string)
 @staticmethod
@@ -119,7 +119,7 @@ class Language(object):
 attrs.IS_TITLE: cls.is_title,
 attrs.IS_UPPER: cls.is_upper,
 attrs.LIKE_URL: cls.like_url,
-attrs.LIKE_NUM: cls.like_number,
+attrs.LIKE_NUM: cls.like_num,
 attrs.LIKE_EMAIL: cls.like_email,
 attrs.IS_STOP: cls.is_stop,
 attrs.IS_OOV: lambda string: True
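The rename makes the static method match the LIKE_NUM attribute it backs. A minimal sketch of that attribute in use, assuming a spaCy 0.99-era install with the English data downloaded:

    from spacy.en import English

    nlp = English()
    doc = nlp(u'The year 1984')
    # like_num is computed per lexeme via Language.like_num
    print([(w.orth_, w.like_num) for w in doc])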


@@ -51,7 +51,7 @@ cdef class Lexeme:
 def __get__(self):
 cdef int i
 for i in range(self.vocab.vectors_length):
-if self.c.repvec[i] != 0:
+if self.c.vector[i] != 0:
 return True
 else:
 return False
@@ -74,14 +74,14 @@ cdef class Lexeme:
 "to install the data."
 )
-repvec_view = <float[:length,]>self.c.repvec
-return numpy.asarray(repvec_view)
+vector_view = <float[:length,]>self.c.vector
+return numpy.asarray(vector_view)
 def __set__(self, vector):
 assert len(vector) == self.vocab.vectors_length
 cdef float value
 for i, value in enumerate(vector):
-self.c.repvec[i] = value
+self.c.vector[i] = value
 property repvec:
 def __get__(self):
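Throughout this commit repvec is renamed to vector; the old repvec property is kept as an alias (the unchanged "property repvec" block above). A minimal sketch of the new spelling, assuming the English vectors are installed:

    from spacy.en import English

    nlp = English()
    apple = nlp.vocab[u'apple']
    vec = apple.vector   # formerly apple.repvec; a numpy array
    print(vec.shape)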


@@ -215,7 +215,7 @@ cdef class Matcher:
 cdef Pattern* state
 matches = []
 for token_i in range(doc.length):
-token = &doc.data[token_i]
+token = &doc.c[token_i]
 q = 0
 # Go over the open matches, extending or finalizing if able. Otherwise,
 # we over-write them (q doesn't advance)
@@ -286,7 +286,7 @@ cdef class PhraseMatcher:
 for i in range(self.max_length):
 self._phrase_key[i] = 0
 for i, tag in enumerate(tags):
-lexeme = self.vocab[tokens.data[i].lex.orth]
+lexeme = self.vocab[tokens.c[i].lex.orth]
 lexeme.set_flag(tag, True)
 self._phrase_key[i] = lexeme.orth
 cdef hash_t key = hash64(self._phrase_key, self.max_length * sizeof(attr_t), 0)
@@ -309,7 +309,7 @@ cdef class PhraseMatcher:
 for i in range(self.max_length):
 self._phrase_key[i] = 0
 for i, j in enumerate(range(start, end)):
-self._phrase_key[i] = doc.data[j].lex.orth
+self._phrase_key[i] = doc.c[j].lex.orth
 cdef hash_t key = hash64(self._phrase_key, self.max_length * sizeof(attr_t), 0)
 if self.phrase_ids.get(key):
 return True


@@ -38,6 +38,8 @@ cdef class Morphology:
 tag_id = self.reverse_index[self.strings[tag]]
 else:
 tag_id = tag
+if tag_id >= self.n_tags:
+raise ValueError("Unknown tag: %s" % tag)
 analysis = <MorphAnalysisC*>self._cache.get(tag_id, token.lex.orth)
 if analysis is NULL:
 analysis = <MorphAnalysisC*>self.mem.alloc(1, sizeof(MorphAnalysisC))
@@ -86,6 +88,8 @@ cdef class Morphology:
 return orth
 cdef unicode py_string = self.strings[orth]
 if pos != NOUN and pos != VERB and pos != ADJ and pos != PUNCT:
+# TODO: This should lower-case
+# return self.strings[py_string.lower()]
 return orth
 cdef set lemma_strings
 cdef unicode lemma_string


@@ -155,10 +155,10 @@ cdef class Packer:
 self.char_codec.encode(bytearray(utf8_str), bits)
 cdef int i, j
 for i in range(doc.length):
-for j in range(doc.data[i].lex.length-1):
+for j in range(doc.c[i].lex.length-1):
 bits.append(False)
 bits.append(True)
-if doc.data[i].spacy:
+if doc.c[i].spacy:
 bits.append(False)
 return bits


@@ -17,10 +17,7 @@ try:
 except ImportError:
 import io
-try:
-import ujson as json
-except ImportError:
-import json
+import ujson as json
 cpdef hash_t hash_string(unicode string) except 0:


@@ -5,7 +5,7 @@ from .parts_of_speech cimport univ_pos_t
 cdef struct LexemeC:
-float* repvec
+float* vector
 flags_t flags
@@ -32,18 +32,8 @@ cdef struct Entity:
 int label
-cdef struct Constituent:
-const TokenC* head
-const Constituent* parent
-const Constituent* first
-const Constituent* last
-int label
-int length
 cdef struct TokenC:
 const LexemeC* lex
-const Constituent* ctnt
 uint64_t morph
 univ_pos_t pos
 bint spacy


@@ -84,7 +84,7 @@ cdef class Parser:
 return cls(strings, moves, model)
 def __call__(self, Doc tokens):
-cdef StateClass stcls = StateClass.init(tokens.data, tokens.length)
+cdef StateClass stcls = StateClass.init(tokens.c, tokens.length)
 self.moves.initialize_state(stcls)
 cdef Example eg = Example(self.model.n_classes, CONTEXT_SIZE,
@@ -112,7 +112,7 @@ cdef class Parser:
 def train(self, Doc tokens, GoldParse gold):
 self.moves.preprocess_gold(gold)
-cdef StateClass stcls = StateClass.init(tokens.data, tokens.length)
+cdef StateClass stcls = StateClass.init(tokens.c, tokens.length)
 self.moves.initialize_state(stcls)
 cdef Example eg = Example(self.model.n_classes, CONTEXT_SIZE,
 self.model.n_feats, self.model.n_feats)
@@ -143,7 +143,7 @@ cdef class StepwiseState:
 def __init__(self, Parser parser, Doc doc):
 self.parser = parser
 self.doc = doc
-self.stcls = StateClass.init(doc.data, doc.length)
+self.stcls = StateClass.init(doc.c, doc.length)
 self.parser.moves.initialize_state(self.stcls)
 self.eg = Example(self.parser.model.n_classes, CONTEXT_SIZE,
 self.parser.model.n_feats, self.parser.model.n_feats)


@@ -141,9 +141,9 @@ cdef class Tagger:
 cdef int i
 cdef const weight_t* scores
 for i in range(tokens.length):
-if tokens.data[i].pos == 0:
-guess = self.predict(i, tokens.data)
-self.vocab.morphology.assign_tag(&tokens.data[i], guess)
+if tokens.c[i].pos == 0:
+guess = self.predict(i, tokens.c)
+self.vocab.morphology.assign_tag(&tokens.c[i], guess)
 tokens.is_tagged = True
 tokens._py_tokens = [None] * tokens.length
@@ -154,7 +154,7 @@ cdef class Tagger:
 def tag_from_strings(self, Doc tokens, object tag_strs):
 cdef int i
 for i in range(tokens.length):
-self.vocab.morphology.assign_tag(&tokens.data[i], tag_strs[i])
+self.vocab.morphology.assign_tag(&tokens.c[i], tag_strs[i])
 tokens.is_tagged = True
 tokens._py_tokens = [None] * tokens.length
@@ -170,13 +170,13 @@ cdef class Tagger:
 [g for g in gold_tag_strs if g is not None and g not in self.tag_names])
 correct = 0
 for i in range(tokens.length):
-guess = self.update(i, tokens.data, golds[i])
+guess = self.update(i, tokens.c, golds[i])
 loss = golds[i] != -1 and guess != golds[i]
-self.vocab.morphology.assign_tag(&tokens.data[i], guess)
+self.vocab.morphology.assign_tag(&tokens.c[i], guess)
 correct += loss == 0
-self.freqs[TAG][tokens.data[i].tag] += 1
+self.freqs[TAG][tokens.c[i].tag] += 1
 return correct
 cdef int predict(self, int i, const TokenC* tokens) except -1:
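A small usage sketch of the string-tag path touched above, with tagger and doc as assumed, pre-existing objects; the tag_from_strings(doc, tag_strs) signature is taken from the diff, and assign_tag now routes through Morphology:

    tags = [u'DT', u'NN', u'VBZ']   # one tag string per token in doc
    tagger.tag_from_strings(doc, tags)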


@@ -38,3 +38,14 @@ def test_left_right(EN):
 for child in word.rights:
 assert child.head.i == word.i
+@pytest.mark.models
+def test_lemmas(EN):
+orig = EN(u'The geese are flying')
+result = Doc(orig.vocab).from_bytes(orig.to_bytes())
+the, geese, are, flying = result
+assert geese.lemma_ == 'goose'
+assert are.lemma_ == 'be'
+assert flying.lemma_ == 'fly'


@@ -6,6 +6,7 @@ import pytest
 import numpy
 from spacy.language import Language
+from spacy.en import English
 from spacy.vocab import Vocab
 from spacy.tokens.doc import Doc
 from spacy.tokenizer import Tokenizer
@@ -20,7 +21,7 @@ from spacy.serialize.bits import BitArray
 @pytest.fixture
 def vocab():
-vocab = Vocab(Language.default_lex_attrs())
+vocab = English.default_vocab()
 lex = vocab['dog']
 assert vocab[vocab.strings['dog']].orth_ == 'dog'
 lex = vocab['the']
@@ -64,6 +65,7 @@ def test_packer_unannotated(tokenizer):
 assert result.string == 'the dog jumped'
+@pytest.mark.models
 def test_packer_annotated(tokenizer):
 vocab = tokenizer.vocab
 nn = vocab.strings['NN']


@@ -5,6 +5,11 @@ import pickle
 import pytest
 import tempfile
+try:
+unicode
+except NameError:
+unicode = str
 @pytest.mark.models
 def test_pickle_english(EN):
 file_ = io.BytesIO()
@@ -21,7 +26,7 @@ def test_cloudpickle_to_file(EN):
 p = cloudpickle.CloudPickler(f)
 p.dump(EN)
 f.close()
-loaded_en = cloudpickle.load(open(f.name))
+loaded_en = cloudpickle.load(open(f.name, 'rb'))
 os.unlink(f.name)
 doc = loaded_en(unicode('test parse'))
 assert len(doc) == 2


@@ -113,7 +113,7 @@ cdef class Tokenizer:
 self._tokenize(tokens, span, key)
 in_ws = not in_ws
 if uc == ' ':
-tokens.data[tokens.length - 1].spacy = True
+tokens.c[tokens.length - 1].spacy = True
 start = i + 1
 else:
 start = i
@@ -125,7 +125,7 @@ cdef class Tokenizer:
 cache_hit = self._try_cache(key, tokens)
 if not cache_hit:
 self._tokenize(tokens, span, key)
-tokens.data[tokens.length - 1].spacy = string[-1] == ' '
+tokens.c[tokens.length - 1].spacy = string[-1] == ' '
 return tokens
 cdef int _try_cache(self, hash_t key, Doc tokens) except -1:
@@ -148,7 +148,7 @@ cdef class Tokenizer:
 orig_size = tokens.length
 span = self._split_affixes(span, &prefixes, &suffixes)
 self._attach_tokens(tokens, span, &prefixes, &suffixes)
-self._save_cached(&tokens.data[orig_size], orig_key, tokens.length - orig_size)
+self._save_cached(&tokens.c[orig_size], orig_key, tokens.length - orig_size)
 cdef unicode _split_affixes(self, unicode string, vector[const LexemeC*] *prefixes,
 vector[const LexemeC*] *suffixes):


@@ -1,5 +1,5 @@
 from .doc import Doc
 from .token import Token
-from .spans import Span
+from .span import Span
 __all__ = [Doc, Token, Span]
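The module spacy.tokens.spans is renamed to spacy.tokens.span (matching the setup.py and MOD_NAMES changes above); the package-level import is unchanged. A short sketch of both import paths after this commit:

    from spacy.tokens import Span        # unchanged public import
    from spacy.tokens.span import Span   # new module name (was spacy.tokens.spans)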


@@ -26,7 +26,7 @@ cdef class Doc:
 cdef public object _vector
 cdef public object _vector_norm
-cdef TokenC* data
+cdef TokenC* c
 cdef public bint is_tagged
 cdef public bint is_parsed


@@ -18,7 +18,7 @@ from ..attrs cimport POS, LEMMA, TAG, DEP, HEAD, SPACY, ENT_IOB, ENT_TYPE
 from ..parts_of_speech cimport CONJ, PUNCT, NOUN
 from ..parts_of_speech cimport univ_pos_t
 from ..lexeme cimport Lexeme
-from .spans cimport Span
+from .span cimport Span
 from .token cimport Token
 from ..serialize.bits cimport BitArray
 from ..util import normalize_slice
@@ -73,7 +73,7 @@ cdef class Doc:
 data_start[i].lex = &EMPTY_LEXEME
 data_start[i].l_edge = i
 data_start[i].r_edge = i
-self.data = data_start + PADDING
+self.c = data_start + PADDING
 self.max_length = size
 self.length = 0
 self.is_tagged = False
@@ -97,7 +97,7 @@ cdef class Doc:
 if self._py_tokens[i] is not None:
 return self._py_tokens[i]
 else:
-return Token.cinit(self.vocab, &self.data[i], i, self)
+return Token.cinit(self.vocab, &self.c[i], i, self)
 def __iter__(self):
 """Iterate over the tokens.
@@ -110,7 +110,7 @@ cdef class Doc:
 if self._py_tokens[i] is not None:
 yield self._py_tokens[i]
 else:
-yield Token.cinit(self.vocab, &self.data[i], i, self)
+yield Token.cinit(self.vocab, &self.c[i], i, self)
 def __len__(self):
 return self.length
@@ -134,10 +134,6 @@ cdef class Doc:
 return 0.0
 return numpy.dot(self.vector, other.vector) / (self.vector_norm * other.vector_norm)
-property repvec:
-def __get__(self):
-return self.vector
 property vector:
 def __get__(self):
 if self._vector is None:
@@ -191,7 +187,7 @@ cdef class Doc:
 cdef int label = 0
 output = []
 for i in range(self.length):
-token = &self.data[i]
+token = &self.c[i]
 if token.ent_iob == 1:
 assert start != -1
 elif token.ent_iob == 2 or token.ent_iob == 0:
@@ -216,23 +212,23 @@ cdef class Doc:
 # 4. Test more nuanced date and currency regex
 cdef int i
 for i in range(self.length):
-self.data[i].ent_type = 0
-self.data[i].ent_iob = 0
+self.c[i].ent_type = 0
+self.c[i].ent_iob = 0
 cdef attr_t ent_type
 cdef int start, end
 for ent_type, start, end in ents:
 if ent_type is None or ent_type < 0:
 # Mark as O
 for i in range(start, end):
-self.data[i].ent_type = 0
-self.data[i].ent_iob = 2
+self.c[i].ent_type = 0
+self.c[i].ent_iob = 2
 else:
 # Mark (inside) as I
 for i in range(start, end):
-self.data[i].ent_type = ent_type
-self.data[i].ent_iob = 1
+self.c[i].ent_type = ent_type
+self.c[i].ent_iob = 1
 # Set start as B
-self.data[start].ent_iob = 3
+self.c[start].ent_iob = 3
 @property
 def noun_chunks(self):
@@ -249,7 +245,7 @@ cdef class Doc:
 np_deps = [self.vocab.strings[label] for label in labels]
 np_label = self.vocab.strings['NP']
 for i in range(self.length):
-word = &self.data[i]
+word = &self.c[i]
 if word.pos == NOUN and word.dep in np_deps:
 yield Span(self, word.l_edge, i+1, label=np_label)
@@ -267,7 +263,7 @@ cdef class Doc:
 cdef int i
 start = 0
 for i in range(1, self.length):
-if self.data[i].sent_start:
+if self.c[i].sent_start:
 yield Span(self, start, i)
 start = i
 yield Span(self, start, self.length)
@@ -275,7 +271,7 @@ cdef class Doc:
 cdef int push_back(self, LexemeOrToken lex_or_tok, bint has_space) except -1:
 if self.length == self.max_length:
 self._realloc(self.length * 2)
-cdef TokenC* t = &self.data[self.length]
+cdef TokenC* t = &self.c[self.length]
 if LexemeOrToken is const_TokenC_ptr:
 t[0] = lex_or_tok[0]
 else:
@@ -314,7 +310,7 @@ cdef class Doc:
 output = numpy.ndarray(shape=(self.length, len(attr_ids)), dtype=numpy.int32)
 for i in range(self.length):
 for j, feature in enumerate(attr_ids):
-output[i, j] = get_token_attr(&self.data[i], feature)
+output[i, j] = get_token_attr(&self.c[i], feature)
 return output
 def count_by(self, attr_id_t attr_id, exclude=None, PreshCounter counts=None):
@@ -344,11 +340,11 @@ cdef class Doc:
 # Take this check out of the loop, for a bit of extra speed
 if exclude is None:
 for i in range(self.length):
-counts.inc(get_token_attr(&self.data[i], attr_id), 1)
+counts.inc(get_token_attr(&self.c[i], attr_id), 1)
 else:
 for i in range(self.length):
 if not exclude(self[i]):
-attr = get_token_attr(&self.data[i], attr_id)
+attr = get_token_attr(&self.c[i], attr_id)
 counts.inc(attr, 1)
 if output_dict:
 return dict(counts)
@@ -361,12 +357,12 @@ cdef class Doc:
 # words out-of-bounds, and get out-of-bounds markers.
 # Now that we want to realloc, we need the address of the true start,
 # so we jump the pointer back PADDING places.
-cdef TokenC* data_start = self.data - PADDING
+cdef TokenC* data_start = self.c - PADDING
 data_start = <TokenC*>self.mem.realloc(data_start, n * sizeof(TokenC))
-self.data = data_start + PADDING
+self.c = data_start + PADDING
 cdef int i
 for i in range(self.length, self.max_length + PADDING):
-self.data[i].lex = &EMPTY_LEXEME
+self.c[i].lex = &EMPTY_LEXEME
 cdef int set_parse(self, const TokenC* parsed) except -1:
 # TODO: This method is fairly misleading atm. It's used by Parser
@@ -375,14 +371,14 @@ cdef class Doc:
 # Probably we should use from_array?
 self.is_parsed = True
 for i in range(self.length):
-self.data[i] = parsed[i]
-assert self.data[i].l_edge <= i
-assert self.data[i].r_edge >= i
+self.c[i] = parsed[i]
+assert self.c[i].l_edge <= i
+assert self.c[i].r_edge >= i
 def from_array(self, attrs, array):
 cdef int i, col
 cdef attr_id_t attr_id
-cdef TokenC* tokens = self.data
+cdef TokenC* tokens = self.c
 cdef int length = len(array)
 cdef attr_t[:] values
 for col, attr_id in enumerate(attrs):
@@ -398,7 +394,8 @@ cdef class Doc:
 self.is_parsed = True
 elif attr_id == TAG:
 for i in range(length):
-tokens[i].tag = values[i]
+self.vocab.morphology.assign_tag(&tokens[i],
+self.vocab.morphology.reverse_index[values[i]])
 if not self.is_tagged and tokens[i].tag != 0:
 self.is_tagged = True
 elif attr_id == POS:
@@ -413,7 +410,9 @@ cdef class Doc:
 elif attr_id == ENT_TYPE:
 for i in range(length):
 tokens[i].ent_type = values[i]
-set_children_from_heads(self.data, self.length)
+else:
+raise ValueError("Unknown attribute ID: %d" % attr_id)
+set_children_from_heads(self.c, self.length)
 return self
 def to_bytes(self):
@@ -463,9 +462,9 @@ cdef class Doc:
 cdef int start = -1
 cdef int end = -1
 for i in range(self.length):
-if self.data[i].idx == start_idx:
+if self.c[i].idx == start_idx:
 start = i
-if (self.data[i].idx + self.data[i].lex.length) == end_idx:
+if (self.c[i].idx + self.c[i].lex.length) == end_idx:
 if start == -1:
 return None
 end = i + 1
@@ -488,11 +487,12 @@ cdef class Doc:
 new_orth = new_orth[:-len(span[-1].whitespace_)]
 cdef const LexemeC* lex = self.vocab.get(self.mem, new_orth)
 # House the new merged token where it starts
-cdef TokenC* token = &self.data[start]
-token.spacy = self.data[end-1].spacy
-# What to do about morphology??
-# TODO: token.morph = ???
-token.tag = self.vocab.strings[tag]
+cdef TokenC* token = &self.c[start]
+token.spacy = self.c[end-1].spacy
+if tag in self.vocab.morphology.tag_map:
+self.vocab.morphology.assign_tag(token, tag)
+else:
+token.tag = self.vocab.strings[tag]
 token.lemma = self.vocab.strings[lemma]
 if ent_type == 'O':
 token.ent_iob = 2
@@ -511,31 +511,31 @@ cdef class Doc:
 # as it modifies the character offsets in the doc
 token.lex = lex
 for i in range(self.length):
-self.data[i].head += i
+self.c[i].head += i
 # Set the head of the merged token, and its dep relation, from the Span
-token.head = self.data[span_root].head
+token.head = self.c[span_root].head
 # Adjust deps before shrinking tokens
 # Tokens which point into the merged token should now point to it
 # Subtract the offset from all tokens which point to >= end
 offset = (end - start) - 1
 for i in range(self.length):
-head_idx = self.data[i].head
+head_idx = self.c[i].head
 if start <= head_idx < end:
-self.data[i].head = start
+self.c[i].head = start
 elif head_idx >= end:
-self.data[i].head -= offset
+self.c[i].head -= offset
 # Now compress the token array
 for i in range(end, self.length):
-self.data[i - offset] = self.data[i]
+self.c[i - offset] = self.c[i]
 for i in range(self.length - offset, self.length):
-memset(&self.data[i], 0, sizeof(TokenC))
-self.data[i].lex = &EMPTY_LEXEME
+memset(&self.c[i], 0, sizeof(TokenC))
+self.c[i].lex = &EMPTY_LEXEME
 self.length -= offset
 for i in range(self.length):
 # ...And, set heads back to a relative position
-self.data[i].head -= i
+self.c[i].head -= i
 # Set the left/right children, left/right edges
-set_children_from_heads(self.data, self.length)
+set_children_from_heads(self.c, self.length)
 # Clear the cached Python objects
 self._py_tokens = [None] * self.length
 # Return the merged Python object
@@ -569,3 +569,9 @@ cdef int set_children_from_heads(TokenC* tokens, int length) except -1:
 if child.r_edge > head.r_edge:
 head.r_edge = child.r_edge
 head.r_kids += 1
+# Set sentence starts
+for i in range(length):
+if tokens[i].head == 0 and tokens[i].dep != 0:
+tokens[tokens[i].l_edge].sent_start = True
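set_children_from_heads now also marks sentence starts, so sentence boundaries fall out of the dependency heads once a parse is set. A minimal sketch of reading them back, assuming the 0.99-era English pipeline with its data installed:

    from spacy.en import English

    nlp = English()
    doc = nlp(u'This is a sentence. This is another one.')
    sentences = list(doc.sents)   # each Span begins at a token flagged sent_start
    print(len(sentences))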


@@ -177,12 +177,12 @@ cdef class Span:
 def __get__(self):
 # This should probably be called 'head', and the other one called
 # 'gov'. But we went with 'head' elsehwhere, and now we're stuck =/
-cdef const TokenC* start = &self.doc.data[self.start]
-cdef const TokenC* end = &self.doc.data[self.end]
+cdef const TokenC* start = &self.doc.c[self.start]
+cdef const TokenC* end = &self.doc.c[self.end]
 head = start
 while start <= (head + head.head) < end and head.head != 0:
 head += head.head
-return self.doc[head - self.doc.data]
+return self.doc[head - self.doc.c]
 property lefts:
 """Tokens that are to the left of the Span, whose head is within the Span."""


@@ -31,7 +31,7 @@ cdef class Token:
 def __cinit__(self, Vocab vocab, Doc doc, int offset):
 self.vocab = vocab
 self.doc = doc
-self.c = &self.doc.data[offset]
+self.c = &self.doc.c[offset]
 self.i = offset
 self.array_len = doc.length
@@ -143,7 +143,7 @@ cdef class Token:
 def __get__(self):
 cdef int i
 for i in range(self.vocab.vectors_length):
-if self.c.lex.repvec[i] != 0:
+if self.c.lex.vector[i] != 0:
 return True
 else:
 return False
@@ -158,8 +158,8 @@ cdef class Token:
 "\npython -m spacy.en.download all\n"
 "to install the data."
 )
-repvec_view = <float[:length,]>self.c.lex.repvec
-return numpy.asarray(repvec_view)
+vector_view = <float[:length,]>self.c.lex.vector
+return numpy.asarray(vector_view)
 property repvec:
 def __get__(self):
@@ -259,14 +259,11 @@ cdef class Token:
 def __get__(self):
 """Get a list of conjoined words."""
 cdef Token word
-conjuncts = []
 if self.dep_ != 'conj':
 for word in self.rights:
 if word.dep_ == 'conj':
 yield word
 yield from word.conjuncts
-conjuncts.append(word)
-conjuncts.extend(word.conjuncts)
 property ent_type:
 def __get__(self):


@@ -40,7 +40,7 @@ DEF MAX_VEC_SIZE = 100000
 cdef float[MAX_VEC_SIZE] EMPTY_VEC
 memset(EMPTY_VEC, 0, sizeof(EMPTY_VEC))
 memset(&EMPTY_LEXEME, 0, sizeof(LexemeC))
-EMPTY_LEXEME.repvec = EMPTY_VEC
+EMPTY_LEXEME.vector = EMPTY_VEC
 cdef class Vocab:
@@ -62,9 +62,11 @@ cdef class Vocab:
 cdef Vocab self = cls(get_lex_attr=get_lex_attr, tag_map=tag_map,
 lemmatizer=lemmatizer, serializer_freqs=serializer_freqs)
-with io.open(path.join(data_dir, 'strings.json'), 'r', encoding='utf8') as file_:
-self.strings.load(file_)
-self.load_lexemes(path.join(data_dir, 'lexemes.bin'))
+if path.exists(path.join(data_dir, 'strings.json')):
+with io.open(path.join(data_dir, 'strings.json'), 'r', encoding='utf8') as file_:
+self.strings.load(file_)
+self.load_lexemes(path.join(data_dir, 'lexemes.bin'))
 if path.exists(path.join(data_dir, 'vec.bin')):
 self.vectors_length = self.load_vectors_from_bin_loc(path.join(data_dir, 'vec.bin'))
 return self
@@ -160,7 +162,7 @@ cdef class Vocab:
 lex.orth = self.strings[string]
 lex.length = len(string)
 lex.id = self.length
-lex.repvec = <float*>mem.alloc(self.vectors_length, sizeof(float))
+lex.vector = <float*>mem.alloc(self.vectors_length, sizeof(float))
 if self.get_lex_attr is not None:
 for attr, func in self.get_lex_attr.items():
 value = func(string)
@@ -285,7 +287,7 @@ cdef class Vocab:
 fp.read_into(&lexeme.sentiment, 1, sizeof(lexeme.sentiment))
 fp.read_into(&lexeme.l2_norm, 1, sizeof(lexeme.l2_norm))
-lexeme.repvec = EMPTY_VEC
+lexeme.vector = EMPTY_VEC
 py_str = self.strings[lexeme.orth]
 key = hash_string(py_str)
 self._by_hash.set(key, lexeme)
@@ -304,7 +306,7 @@ cdef class Vocab:
 cdef CFile out_file = CFile(out_loc, 'wb')
 for lexeme in self:
 word_str = lexeme.orth_.encode('utf8')
-vec = lexeme.c.repvec
+vec = lexeme.c.vector
 word_len = len(word_str)
 out_file.write_from(&word_len, 1, sizeof(word_len))
@@ -329,10 +331,10 @@ cdef class Vocab:
 vec_len, len(pieces))
 orth = self.strings[word_str]
 lexeme = <LexemeC*><void*>self.get_by_orth(self.mem, orth)
-lexeme.repvec = <float*>self.mem.alloc(self.vectors_length, sizeof(float))
+lexeme.vector = <float*>self.mem.alloc(self.vectors_length, sizeof(float))
 for i, val_str in enumerate(pieces):
-lexeme.repvec[i] = float(val_str)
+lexeme.vector[i] = float(val_str)
 return vec_len
 def load_vectors_from_bin_loc(self, loc):
@@ -374,12 +376,12 @@ cdef class Vocab:
 for orth, lex_addr in self._by_orth.items():
 lex = <LexemeC*>lex_addr
 if lex.lower < vectors.size():
-lex.repvec = vectors[lex.lower]
+lex.vector = vectors[lex.lower]
 for i in range(vec_len):
-lex.l2_norm += (lex.repvec[i] * lex.repvec[i])
+lex.l2_norm += (lex.vector[i] * lex.vector[i])
 lex.l2_norm = math.sqrt(lex.l2_norm)
 else:
-lex.repvec = EMPTY_VEC
+lex.vector = EMPTY_VEC
 return vec_len


@@ -80,10 +80,10 @@ include ./meta.jade
 |
 | def match_tweet(spacy, text, query):
 | def get_vector(word):
-| return spacy.vocab[word].repvec
+| return spacy.vocab[word].vector
 |
 | tweet = spacy(text)
-| tweet = [w.repvec for w in tweet if w.is_alpha and w.lower_ != query]
+| tweet = [w.vector for w in tweet if w.is_alpha and w.lower_ != query]
 | if tweet:
 | accept = map(get_vector, 'Jeb Cheney Republican 9/11h'.split())
 | reject = map(get_vector, 'garden Reggie hairy'.split())
@@ -147,11 +147,11 @@ include ./meta.jade
 pre.language-python: code
 | def handle_tweet(spacy, resp, query):
 | def get_vector(word):
-| return spacy.vocab[word].repvec
+| return spacy.vocab[word].vector
 |
 | text = resp.get('text', '').decode('utf8')
 | tweet = spacy(text)
-| tweet = [w.repvec for w in tweet if w.is_alpha and w.lower_ != query]
+| tweet = [w.vector for w in tweet if w.is_alpha and w.lower_ != query]
 | if tweet:
 | accept = map(get_vector, 'Jeb Cheney Republican 9/11h'.split())
 | reject = map(get_vector, 'garden Reggie hairy'.split())