mirror of
https://github.com/explosion/spaCy.git
synced 2025-08-07 21:54:54 +03:00
pull changes
This commit is contained in:
commit
cb376eadd3
7
fabfile.py
vendored
7
fabfile.py
vendored
|
@ -54,10 +54,10 @@ def prebuild(build_dir='/tmp/build_spacy'):
|
|||
local('pip install --no-cache-dir -r requirements.txt')
|
||||
local('fab clean make')
|
||||
local('cp -r %s/corpora/en/wordnet corpora/en/' % spacy_dir)
|
||||
local('cp %s/corpora/en/freqs.txt.gz corpora/en/' % spacy_dir)
|
||||
local('PYTHONPATH=`pwd` python bin/init_model.py en lang_data corpora spacy/en/data')
|
||||
local('fab test')
|
||||
local('python setup.py sdist')
|
||||
local('PYTHONPATH=`pwd` python -m spacy.en.download --force all')
|
||||
local('py.test --models spacy/tests/')
|
||||
|
||||
|
||||
def docs():
|
||||
|
@ -121,9 +121,8 @@ def clean():
|
|||
|
||||
def test():
|
||||
with virtualenv(VENV_DIR):
|
||||
# Run each test file separately. pytest is performing poorly, not sure why
|
||||
with lcd(path.dirname(__file__)):
|
||||
local('py.test -x tests/')
|
||||
local('py.test -x spacy/tests')
|
||||
|
||||
|
||||
def train(json_dir=None, dev_loc=None, model_dir=None):
|
||||
|
|
46
setup.py
46
setup.py
|
@ -13,6 +13,21 @@ from distutils.command.build_ext import build_ext
|
|||
|
||||
import platform
|
||||
|
||||
PACKAGE_DATA = {
|
||||
"spacy": ["*.pxd"],
|
||||
"spacy.tokens": ["*.pxd"],
|
||||
"spacy.serialize": ["*.pxd"],
|
||||
"spacy.syntax": ["*.pxd"],
|
||||
"spacy.en": [
|
||||
"*.pxd",
|
||||
"data/wordnet/*.exc",
|
||||
"data/wordnet/index.*",
|
||||
"data/tokenizer/*",
|
||||
"data/vocab/serializer.json"
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
# By subclassing build_extensions we have the actual compiler that will be used which is really known only after finalize_options
|
||||
# http://stackoverflow.com/questions/724664/python-distutils-how-to-get-a-compiler-that-is-going-to-be-used
|
||||
compile_options = {'msvc' : ['/Ox', '/EHsc'] ,
|
||||
|
@ -81,6 +96,8 @@ except OSError:
|
|||
pass
|
||||
|
||||
|
||||
|
||||
|
||||
def clean(mod_names):
|
||||
for name in mod_names:
|
||||
name = name.replace('.', '/')
|
||||
|
@ -128,15 +145,7 @@ def cython_setup(mod_names, language, includes):
|
|||
author_email='honnibal@gmail.com',
|
||||
version=VERSION,
|
||||
url="http://honnibal.github.io/spaCy/",
|
||||
package_data={"spacy": ["*.pxd", "tests/*.py", "tests/*/*.py"],
|
||||
"spacy.tokens": ["*.pxd"],
|
||||
"spacy.serialize": ["*.pxd"],
|
||||
"spacy.en": ["*.pxd", "data/pos/*",
|
||||
"data/wordnet/*", "data/tokenizer/*",
|
||||
"data/vocab/tag_map.json",
|
||||
"data/vocab/lexemes.bin",
|
||||
"data/vocab/strings.json"],
|
||||
"spacy.syntax": ["*.pxd"]},
|
||||
package_data=PACKAGE_DATA,
|
||||
ext_modules=exts,
|
||||
cmdclass={'build_ext': build_ext_cython_subclass},
|
||||
license="MIT",
|
||||
|
@ -154,7 +163,7 @@ def run_setup(exts):
|
|||
'spacy.tests.munge',
|
||||
'spacy.tests.parser',
|
||||
'spacy.tests.serialize',
|
||||
'spacy.tests.spans',
|
||||
'spacy.tests.span',
|
||||
'spacy.tests.tagger',
|
||||
'spacy.tests.tokenizer',
|
||||
'spacy.tests.tokens',
|
||||
|
@ -165,18 +174,11 @@ def run_setup(exts):
|
|||
author_email='honnibal@gmail.com',
|
||||
version=VERSION,
|
||||
url="http://honnibal.github.io/spaCy/",
|
||||
package_data={"spacy": ["*.pxd"],
|
||||
"spacy.en": ["*.pxd", "data/pos/*",
|
||||
"data/wordnet/*", "data/tokenizer/*",
|
||||
"data/vocab/lexemes.bin",
|
||||
"data/vocab/serializer.json",
|
||||
"data/vocab/oov_prob",
|
||||
"data/vocab/strings.txt"],
|
||||
"spacy.syntax": ["*.pxd"]},
|
||||
package_data=PACKAGE_DATA,
|
||||
ext_modules=exts,
|
||||
license="MIT",
|
||||
install_requires=['numpy', 'murmurhash', 'cymem >= 1.30', 'preshed >= 0.43',
|
||||
'thinc >= 3.4.1', "text_unidecode", 'plac', 'six',
|
||||
install_requires=['numpy', 'murmurhash', 'cymem == 1.30', 'preshed == 0.43',
|
||||
'thinc == 3.4.1', "text_unidecode", 'plac', 'six',
|
||||
'ujson', 'cloudpickle'],
|
||||
setup_requires=["headers_workaround"],
|
||||
cmdclass = {'build_ext': build_ext_subclass },
|
||||
|
@ -189,7 +191,7 @@ def run_setup(exts):
|
|||
headers_workaround.install_headers('numpy')
|
||||
|
||||
|
||||
VERSION = '0.97'
|
||||
VERSION = '0.99'
|
||||
def main(modules, is_pypy):
|
||||
language = "cpp"
|
||||
includes = ['.', path.join(sys.prefix, 'include')]
|
||||
|
@ -215,7 +217,7 @@ MOD_NAMES = ['spacy.parts_of_speech', 'spacy.strings',
|
|||
'spacy.syntax.arc_eager',
|
||||
'spacy.syntax._parse_features',
|
||||
'spacy.gold', 'spacy.orth',
|
||||
'spacy.tokens.doc', 'spacy.tokens.spans', 'spacy.tokens.token',
|
||||
'spacy.tokens.doc', 'spacy.tokens.span', 'spacy.tokens.token',
|
||||
'spacy.serialize.packer', 'spacy.serialize.huffman', 'spacy.serialize.bits',
|
||||
'spacy.cfile', 'spacy.matcher',
|
||||
'spacy.syntax.ner',
|
||||
|
|
|
@ -38,6 +38,7 @@ def install_data(url, extract_path, download_path):
|
|||
assert tmp == download_path
|
||||
t = tarfile.open(download_path)
|
||||
t.extractall(extract_path)
|
||||
os.unlink(download_path)
|
||||
|
||||
|
||||
@plac.annotations(
|
||||
|
|
|
@ -88,7 +88,7 @@ class Language(object):
|
|||
return orth.like_url(string)
|
||||
|
||||
@staticmethod
|
||||
def like_number(string):
|
||||
def like_num(string):
|
||||
return orth.like_number(string)
|
||||
|
||||
@staticmethod
|
||||
|
@ -119,7 +119,7 @@ class Language(object):
|
|||
attrs.IS_TITLE: cls.is_title,
|
||||
attrs.IS_UPPER: cls.is_upper,
|
||||
attrs.LIKE_URL: cls.like_url,
|
||||
attrs.LIKE_NUM: cls.like_number,
|
||||
attrs.LIKE_NUM: cls.like_num,
|
||||
attrs.LIKE_EMAIL: cls.like_email,
|
||||
attrs.IS_STOP: cls.is_stop,
|
||||
attrs.IS_OOV: lambda string: True
|
||||
|
|
|
@ -51,7 +51,7 @@ cdef class Lexeme:
|
|||
def __get__(self):
|
||||
cdef int i
|
||||
for i in range(self.vocab.vectors_length):
|
||||
if self.c.repvec[i] != 0:
|
||||
if self.c.vector[i] != 0:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
@ -74,14 +74,14 @@ cdef class Lexeme:
|
|||
"to install the data."
|
||||
)
|
||||
|
||||
repvec_view = <float[:length,]>self.c.repvec
|
||||
return numpy.asarray(repvec_view)
|
||||
vector_view = <float[:length,]>self.c.vector
|
||||
return numpy.asarray(vector_view)
|
||||
|
||||
def __set__(self, vector):
|
||||
assert len(vector) == self.vocab.vectors_length
|
||||
cdef float value
|
||||
for i, value in enumerate(vector):
|
||||
self.c.repvec[i] = value
|
||||
self.c.vector[i] = value
|
||||
|
||||
property repvec:
|
||||
def __get__(self):
|
||||
|
|
|
@ -215,7 +215,7 @@ cdef class Matcher:
|
|||
cdef Pattern* state
|
||||
matches = []
|
||||
for token_i in range(doc.length):
|
||||
token = &doc.data[token_i]
|
||||
token = &doc.c[token_i]
|
||||
q = 0
|
||||
# Go over the open matches, extending or finalizing if able. Otherwise,
|
||||
# we over-write them (q doesn't advance)
|
||||
|
@ -286,7 +286,7 @@ cdef class PhraseMatcher:
|
|||
for i in range(self.max_length):
|
||||
self._phrase_key[i] = 0
|
||||
for i, tag in enumerate(tags):
|
||||
lexeme = self.vocab[tokens.data[i].lex.orth]
|
||||
lexeme = self.vocab[tokens.c[i].lex.orth]
|
||||
lexeme.set_flag(tag, True)
|
||||
self._phrase_key[i] = lexeme.orth
|
||||
cdef hash_t key = hash64(self._phrase_key, self.max_length * sizeof(attr_t), 0)
|
||||
|
@ -309,7 +309,7 @@ cdef class PhraseMatcher:
|
|||
for i in range(self.max_length):
|
||||
self._phrase_key[i] = 0
|
||||
for i, j in enumerate(range(start, end)):
|
||||
self._phrase_key[i] = doc.data[j].lex.orth
|
||||
self._phrase_key[i] = doc.c[j].lex.orth
|
||||
cdef hash_t key = hash64(self._phrase_key, self.max_length * sizeof(attr_t), 0)
|
||||
if self.phrase_ids.get(key):
|
||||
return True
|
||||
|
|
|
@ -38,6 +38,8 @@ cdef class Morphology:
|
|||
tag_id = self.reverse_index[self.strings[tag]]
|
||||
else:
|
||||
tag_id = tag
|
||||
if tag_id >= self.n_tags:
|
||||
raise ValueError("Unknown tag: %s" % tag)
|
||||
analysis = <MorphAnalysisC*>self._cache.get(tag_id, token.lex.orth)
|
||||
if analysis is NULL:
|
||||
analysis = <MorphAnalysisC*>self.mem.alloc(1, sizeof(MorphAnalysisC))
|
||||
|
@ -86,6 +88,8 @@ cdef class Morphology:
|
|||
return orth
|
||||
cdef unicode py_string = self.strings[orth]
|
||||
if pos != NOUN and pos != VERB and pos != ADJ and pos != PUNCT:
|
||||
# TODO: This should lower-case
|
||||
# return self.strings[py_string.lower()]
|
||||
return orth
|
||||
cdef set lemma_strings
|
||||
cdef unicode lemma_string
|
||||
|
|
|
@ -155,10 +155,10 @@ cdef class Packer:
|
|||
self.char_codec.encode(bytearray(utf8_str), bits)
|
||||
cdef int i, j
|
||||
for i in range(doc.length):
|
||||
for j in range(doc.data[i].lex.length-1):
|
||||
for j in range(doc.c[i].lex.length-1):
|
||||
bits.append(False)
|
||||
bits.append(True)
|
||||
if doc.data[i].spacy:
|
||||
if doc.c[i].spacy:
|
||||
bits.append(False)
|
||||
return bits
|
||||
|
||||
|
|
|
@ -17,10 +17,7 @@ try:
|
|||
except ImportError:
|
||||
import io
|
||||
|
||||
try:
|
||||
import ujson as json
|
||||
except ImportError:
|
||||
import json
|
||||
import ujson as json
|
||||
|
||||
|
||||
cpdef hash_t hash_string(unicode string) except 0:
|
||||
|
|
|
@ -5,7 +5,7 @@ from .parts_of_speech cimport univ_pos_t
|
|||
|
||||
|
||||
cdef struct LexemeC:
|
||||
float* repvec
|
||||
float* vector
|
||||
|
||||
flags_t flags
|
||||
|
||||
|
@ -32,18 +32,8 @@ cdef struct Entity:
|
|||
int label
|
||||
|
||||
|
||||
cdef struct Constituent:
|
||||
const TokenC* head
|
||||
const Constituent* parent
|
||||
const Constituent* first
|
||||
const Constituent* last
|
||||
int label
|
||||
int length
|
||||
|
||||
|
||||
cdef struct TokenC:
|
||||
const LexemeC* lex
|
||||
const Constituent* ctnt
|
||||
uint64_t morph
|
||||
univ_pos_t pos
|
||||
bint spacy
|
||||
|
|
|
@ -84,7 +84,7 @@ cdef class Parser:
|
|||
return cls(strings, moves, model)
|
||||
|
||||
def __call__(self, Doc tokens):
|
||||
cdef StateClass stcls = StateClass.init(tokens.data, tokens.length)
|
||||
cdef StateClass stcls = StateClass.init(tokens.c, tokens.length)
|
||||
self.moves.initialize_state(stcls)
|
||||
|
||||
cdef Example eg = Example(self.model.n_classes, CONTEXT_SIZE,
|
||||
|
@ -112,7 +112,7 @@ cdef class Parser:
|
|||
|
||||
def train(self, Doc tokens, GoldParse gold):
|
||||
self.moves.preprocess_gold(gold)
|
||||
cdef StateClass stcls = StateClass.init(tokens.data, tokens.length)
|
||||
cdef StateClass stcls = StateClass.init(tokens.c, tokens.length)
|
||||
self.moves.initialize_state(stcls)
|
||||
cdef Example eg = Example(self.model.n_classes, CONTEXT_SIZE,
|
||||
self.model.n_feats, self.model.n_feats)
|
||||
|
@ -143,7 +143,7 @@ cdef class StepwiseState:
|
|||
def __init__(self, Parser parser, Doc doc):
|
||||
self.parser = parser
|
||||
self.doc = doc
|
||||
self.stcls = StateClass.init(doc.data, doc.length)
|
||||
self.stcls = StateClass.init(doc.c, doc.length)
|
||||
self.parser.moves.initialize_state(self.stcls)
|
||||
self.eg = Example(self.parser.model.n_classes, CONTEXT_SIZE,
|
||||
self.parser.model.n_feats, self.parser.model.n_feats)
|
||||
|
|
|
@ -141,9 +141,9 @@ cdef class Tagger:
|
|||
cdef int i
|
||||
cdef const weight_t* scores
|
||||
for i in range(tokens.length):
|
||||
if tokens.data[i].pos == 0:
|
||||
guess = self.predict(i, tokens.data)
|
||||
self.vocab.morphology.assign_tag(&tokens.data[i], guess)
|
||||
if tokens.c[i].pos == 0:
|
||||
guess = self.predict(i, tokens.c)
|
||||
self.vocab.morphology.assign_tag(&tokens.c[i], guess)
|
||||
|
||||
tokens.is_tagged = True
|
||||
tokens._py_tokens = [None] * tokens.length
|
||||
|
@ -154,7 +154,7 @@ cdef class Tagger:
|
|||
def tag_from_strings(self, Doc tokens, object tag_strs):
|
||||
cdef int i
|
||||
for i in range(tokens.length):
|
||||
self.vocab.morphology.assign_tag(&tokens.data[i], tag_strs[i])
|
||||
self.vocab.morphology.assign_tag(&tokens.c[i], tag_strs[i])
|
||||
tokens.is_tagged = True
|
||||
tokens._py_tokens = [None] * tokens.length
|
||||
|
||||
|
@ -170,13 +170,13 @@ cdef class Tagger:
|
|||
[g for g in gold_tag_strs if g is not None and g not in self.tag_names])
|
||||
correct = 0
|
||||
for i in range(tokens.length):
|
||||
guess = self.update(i, tokens.data, golds[i])
|
||||
guess = self.update(i, tokens.c, golds[i])
|
||||
loss = golds[i] != -1 and guess != golds[i]
|
||||
|
||||
self.vocab.morphology.assign_tag(&tokens.data[i], guess)
|
||||
self.vocab.morphology.assign_tag(&tokens.c[i], guess)
|
||||
|
||||
correct += loss == 0
|
||||
self.freqs[TAG][tokens.data[i].tag] += 1
|
||||
self.freqs[TAG][tokens.c[i].tag] += 1
|
||||
return correct
|
||||
|
||||
cdef int predict(self, int i, const TokenC* tokens) except -1:
|
||||
|
|
|
@ -38,3 +38,14 @@ def test_left_right(EN):
|
|||
for child in word.rights:
|
||||
assert child.head.i == word.i
|
||||
|
||||
|
||||
@pytest.mark.models
|
||||
def test_lemmas(EN):
|
||||
orig = EN(u'The geese are flying')
|
||||
result = Doc(orig.vocab).from_bytes(orig.to_bytes())
|
||||
the, geese, are, flying = result
|
||||
assert geese.lemma_ == 'goose'
|
||||
assert are.lemma_ == 'be'
|
||||
assert flying.lemma_ == 'fly'
|
||||
|
||||
|
||||
|
|
|
@ -6,6 +6,7 @@ import pytest
|
|||
import numpy
|
||||
|
||||
from spacy.language import Language
|
||||
from spacy.en import English
|
||||
from spacy.vocab import Vocab
|
||||
from spacy.tokens.doc import Doc
|
||||
from spacy.tokenizer import Tokenizer
|
||||
|
@ -20,7 +21,7 @@ from spacy.serialize.bits import BitArray
|
|||
|
||||
@pytest.fixture
|
||||
def vocab():
|
||||
vocab = Vocab(Language.default_lex_attrs())
|
||||
vocab = English.default_vocab()
|
||||
lex = vocab['dog']
|
||||
assert vocab[vocab.strings['dog']].orth_ == 'dog'
|
||||
lex = vocab['the']
|
||||
|
@ -64,6 +65,7 @@ def test_packer_unannotated(tokenizer):
|
|||
assert result.string == 'the dog jumped'
|
||||
|
||||
|
||||
@pytest.mark.models
|
||||
def test_packer_annotated(tokenizer):
|
||||
vocab = tokenizer.vocab
|
||||
nn = vocab.strings['NN']
|
||||
|
|
|
@ -5,6 +5,11 @@ import pickle
|
|||
import pytest
|
||||
import tempfile
|
||||
|
||||
try:
|
||||
unicode
|
||||
except NameError:
|
||||
unicode = str
|
||||
|
||||
@pytest.mark.models
|
||||
def test_pickle_english(EN):
|
||||
file_ = io.BytesIO()
|
||||
|
@ -21,7 +26,7 @@ def test_cloudpickle_to_file(EN):
|
|||
p = cloudpickle.CloudPickler(f)
|
||||
p.dump(EN)
|
||||
f.close()
|
||||
loaded_en = cloudpickle.load(open(f.name))
|
||||
loaded_en = cloudpickle.load(open(f.name, 'rb'))
|
||||
os.unlink(f.name)
|
||||
doc = loaded_en(unicode('test parse'))
|
||||
assert len(doc) == 2
|
||||
|
|
|
@ -113,7 +113,7 @@ cdef class Tokenizer:
|
|||
self._tokenize(tokens, span, key)
|
||||
in_ws = not in_ws
|
||||
if uc == ' ':
|
||||
tokens.data[tokens.length - 1].spacy = True
|
||||
tokens.c[tokens.length - 1].spacy = True
|
||||
start = i + 1
|
||||
else:
|
||||
start = i
|
||||
|
@ -125,7 +125,7 @@ cdef class Tokenizer:
|
|||
cache_hit = self._try_cache(key, tokens)
|
||||
if not cache_hit:
|
||||
self._tokenize(tokens, span, key)
|
||||
tokens.data[tokens.length - 1].spacy = string[-1] == ' '
|
||||
tokens.c[tokens.length - 1].spacy = string[-1] == ' '
|
||||
return tokens
|
||||
|
||||
cdef int _try_cache(self, hash_t key, Doc tokens) except -1:
|
||||
|
@ -148,7 +148,7 @@ cdef class Tokenizer:
|
|||
orig_size = tokens.length
|
||||
span = self._split_affixes(span, &prefixes, &suffixes)
|
||||
self._attach_tokens(tokens, span, &prefixes, &suffixes)
|
||||
self._save_cached(&tokens.data[orig_size], orig_key, tokens.length - orig_size)
|
||||
self._save_cached(&tokens.c[orig_size], orig_key, tokens.length - orig_size)
|
||||
|
||||
cdef unicode _split_affixes(self, unicode string, vector[const LexemeC*] *prefixes,
|
||||
vector[const LexemeC*] *suffixes):
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from .doc import Doc
|
||||
from .token import Token
|
||||
from .spans import Span
|
||||
from .span import Span
|
||||
|
||||
__all__ = [Doc, Token, Span]
|
||||
|
|
|
@ -26,7 +26,7 @@ cdef class Doc:
|
|||
cdef public object _vector
|
||||
cdef public object _vector_norm
|
||||
|
||||
cdef TokenC* data
|
||||
cdef TokenC* c
|
||||
|
||||
cdef public bint is_tagged
|
||||
cdef public bint is_parsed
|
||||
|
|
|
@ -18,7 +18,7 @@ from ..attrs cimport POS, LEMMA, TAG, DEP, HEAD, SPACY, ENT_IOB, ENT_TYPE
|
|||
from ..parts_of_speech cimport CONJ, PUNCT, NOUN
|
||||
from ..parts_of_speech cimport univ_pos_t
|
||||
from ..lexeme cimport Lexeme
|
||||
from .spans cimport Span
|
||||
from .span cimport Span
|
||||
from .token cimport Token
|
||||
from ..serialize.bits cimport BitArray
|
||||
from ..util import normalize_slice
|
||||
|
@ -73,7 +73,7 @@ cdef class Doc:
|
|||
data_start[i].lex = &EMPTY_LEXEME
|
||||
data_start[i].l_edge = i
|
||||
data_start[i].r_edge = i
|
||||
self.data = data_start + PADDING
|
||||
self.c = data_start + PADDING
|
||||
self.max_length = size
|
||||
self.length = 0
|
||||
self.is_tagged = False
|
||||
|
@ -97,7 +97,7 @@ cdef class Doc:
|
|||
if self._py_tokens[i] is not None:
|
||||
return self._py_tokens[i]
|
||||
else:
|
||||
return Token.cinit(self.vocab, &self.data[i], i, self)
|
||||
return Token.cinit(self.vocab, &self.c[i], i, self)
|
||||
|
||||
def __iter__(self):
|
||||
"""Iterate over the tokens.
|
||||
|
@ -110,7 +110,7 @@ cdef class Doc:
|
|||
if self._py_tokens[i] is not None:
|
||||
yield self._py_tokens[i]
|
||||
else:
|
||||
yield Token.cinit(self.vocab, &self.data[i], i, self)
|
||||
yield Token.cinit(self.vocab, &self.c[i], i, self)
|
||||
|
||||
def __len__(self):
|
||||
return self.length
|
||||
|
@ -134,10 +134,6 @@ cdef class Doc:
|
|||
return 0.0
|
||||
return numpy.dot(self.vector, other.vector) / (self.vector_norm * other.vector_norm)
|
||||
|
||||
property repvec:
|
||||
def __get__(self):
|
||||
return self.vector
|
||||
|
||||
property vector:
|
||||
def __get__(self):
|
||||
if self._vector is None:
|
||||
|
@ -191,7 +187,7 @@ cdef class Doc:
|
|||
cdef int label = 0
|
||||
output = []
|
||||
for i in range(self.length):
|
||||
token = &self.data[i]
|
||||
token = &self.c[i]
|
||||
if token.ent_iob == 1:
|
||||
assert start != -1
|
||||
elif token.ent_iob == 2 or token.ent_iob == 0:
|
||||
|
@ -216,23 +212,23 @@ cdef class Doc:
|
|||
# 4. Test more nuanced date and currency regex
|
||||
cdef int i
|
||||
for i in range(self.length):
|
||||
self.data[i].ent_type = 0
|
||||
self.data[i].ent_iob = 0
|
||||
self.c[i].ent_type = 0
|
||||
self.c[i].ent_iob = 0
|
||||
cdef attr_t ent_type
|
||||
cdef int start, end
|
||||
for ent_type, start, end in ents:
|
||||
if ent_type is None or ent_type < 0:
|
||||
# Mark as O
|
||||
for i in range(start, end):
|
||||
self.data[i].ent_type = 0
|
||||
self.data[i].ent_iob = 2
|
||||
self.c[i].ent_type = 0
|
||||
self.c[i].ent_iob = 2
|
||||
else:
|
||||
# Mark (inside) as I
|
||||
for i in range(start, end):
|
||||
self.data[i].ent_type = ent_type
|
||||
self.data[i].ent_iob = 1
|
||||
self.c[i].ent_type = ent_type
|
||||
self.c[i].ent_iob = 1
|
||||
# Set start as B
|
||||
self.data[start].ent_iob = 3
|
||||
self.c[start].ent_iob = 3
|
||||
|
||||
@property
|
||||
def noun_chunks(self):
|
||||
|
@ -249,7 +245,7 @@ cdef class Doc:
|
|||
np_deps = [self.vocab.strings[label] for label in labels]
|
||||
np_label = self.vocab.strings['NP']
|
||||
for i in range(self.length):
|
||||
word = &self.data[i]
|
||||
word = &self.c[i]
|
||||
if word.pos == NOUN and word.dep in np_deps:
|
||||
yield Span(self, word.l_edge, i+1, label=np_label)
|
||||
|
||||
|
@ -267,7 +263,7 @@ cdef class Doc:
|
|||
cdef int i
|
||||
start = 0
|
||||
for i in range(1, self.length):
|
||||
if self.data[i].sent_start:
|
||||
if self.c[i].sent_start:
|
||||
yield Span(self, start, i)
|
||||
start = i
|
||||
yield Span(self, start, self.length)
|
||||
|
@ -275,7 +271,7 @@ cdef class Doc:
|
|||
cdef int push_back(self, LexemeOrToken lex_or_tok, bint has_space) except -1:
|
||||
if self.length == self.max_length:
|
||||
self._realloc(self.length * 2)
|
||||
cdef TokenC* t = &self.data[self.length]
|
||||
cdef TokenC* t = &self.c[self.length]
|
||||
if LexemeOrToken is const_TokenC_ptr:
|
||||
t[0] = lex_or_tok[0]
|
||||
else:
|
||||
|
@ -314,7 +310,7 @@ cdef class Doc:
|
|||
output = numpy.ndarray(shape=(self.length, len(attr_ids)), dtype=numpy.int32)
|
||||
for i in range(self.length):
|
||||
for j, feature in enumerate(attr_ids):
|
||||
output[i, j] = get_token_attr(&self.data[i], feature)
|
||||
output[i, j] = get_token_attr(&self.c[i], feature)
|
||||
return output
|
||||
|
||||
def count_by(self, attr_id_t attr_id, exclude=None, PreshCounter counts=None):
|
||||
|
@ -344,11 +340,11 @@ cdef class Doc:
|
|||
# Take this check out of the loop, for a bit of extra speed
|
||||
if exclude is None:
|
||||
for i in range(self.length):
|
||||
counts.inc(get_token_attr(&self.data[i], attr_id), 1)
|
||||
counts.inc(get_token_attr(&self.c[i], attr_id), 1)
|
||||
else:
|
||||
for i in range(self.length):
|
||||
if not exclude(self[i]):
|
||||
attr = get_token_attr(&self.data[i], attr_id)
|
||||
attr = get_token_attr(&self.c[i], attr_id)
|
||||
counts.inc(attr, 1)
|
||||
if output_dict:
|
||||
return dict(counts)
|
||||
|
@ -361,12 +357,12 @@ cdef class Doc:
|
|||
# words out-of-bounds, and get out-of-bounds markers.
|
||||
# Now that we want to realloc, we need the address of the true start,
|
||||
# so we jump the pointer back PADDING places.
|
||||
cdef TokenC* data_start = self.data - PADDING
|
||||
cdef TokenC* data_start = self.c - PADDING
|
||||
data_start = <TokenC*>self.mem.realloc(data_start, n * sizeof(TokenC))
|
||||
self.data = data_start + PADDING
|
||||
self.c = data_start + PADDING
|
||||
cdef int i
|
||||
for i in range(self.length, self.max_length + PADDING):
|
||||
self.data[i].lex = &EMPTY_LEXEME
|
||||
self.c[i].lex = &EMPTY_LEXEME
|
||||
|
||||
cdef int set_parse(self, const TokenC* parsed) except -1:
|
||||
# TODO: This method is fairly misleading atm. It's used by Parser
|
||||
|
@ -375,14 +371,14 @@ cdef class Doc:
|
|||
# Probably we should use from_array?
|
||||
self.is_parsed = True
|
||||
for i in range(self.length):
|
||||
self.data[i] = parsed[i]
|
||||
assert self.data[i].l_edge <= i
|
||||
assert self.data[i].r_edge >= i
|
||||
self.c[i] = parsed[i]
|
||||
assert self.c[i].l_edge <= i
|
||||
assert self.c[i].r_edge >= i
|
||||
|
||||
def from_array(self, attrs, array):
|
||||
cdef int i, col
|
||||
cdef attr_id_t attr_id
|
||||
cdef TokenC* tokens = self.data
|
||||
cdef TokenC* tokens = self.c
|
||||
cdef int length = len(array)
|
||||
cdef attr_t[:] values
|
||||
for col, attr_id in enumerate(attrs):
|
||||
|
@ -398,7 +394,8 @@ cdef class Doc:
|
|||
self.is_parsed = True
|
||||
elif attr_id == TAG:
|
||||
for i in range(length):
|
||||
tokens[i].tag = values[i]
|
||||
self.vocab.morphology.assign_tag(&tokens[i],
|
||||
self.vocab.morphology.reverse_index[values[i]])
|
||||
if not self.is_tagged and tokens[i].tag != 0:
|
||||
self.is_tagged = True
|
||||
elif attr_id == POS:
|
||||
|
@ -413,7 +410,9 @@ cdef class Doc:
|
|||
elif attr_id == ENT_TYPE:
|
||||
for i in range(length):
|
||||
tokens[i].ent_type = values[i]
|
||||
set_children_from_heads(self.data, self.length)
|
||||
else:
|
||||
raise ValueError("Unknown attribute ID: %d" % attr_id)
|
||||
set_children_from_heads(self.c, self.length)
|
||||
return self
|
||||
|
||||
def to_bytes(self):
|
||||
|
@ -463,9 +462,9 @@ cdef class Doc:
|
|||
cdef int start = -1
|
||||
cdef int end = -1
|
||||
for i in range(self.length):
|
||||
if self.data[i].idx == start_idx:
|
||||
if self.c[i].idx == start_idx:
|
||||
start = i
|
||||
if (self.data[i].idx + self.data[i].lex.length) == end_idx:
|
||||
if (self.c[i].idx + self.c[i].lex.length) == end_idx:
|
||||
if start == -1:
|
||||
return None
|
||||
end = i + 1
|
||||
|
@ -488,10 +487,11 @@ cdef class Doc:
|
|||
new_orth = new_orth[:-len(span[-1].whitespace_)]
|
||||
cdef const LexemeC* lex = self.vocab.get(self.mem, new_orth)
|
||||
# House the new merged token where it starts
|
||||
cdef TokenC* token = &self.data[start]
|
||||
token.spacy = self.data[end-1].spacy
|
||||
# What to do about morphology??
|
||||
# TODO: token.morph = ???
|
||||
cdef TokenC* token = &self.c[start]
|
||||
token.spacy = self.c[end-1].spacy
|
||||
if tag in self.vocab.morphology.tag_map:
|
||||
self.vocab.morphology.assign_tag(token, tag)
|
||||
else:
|
||||
token.tag = self.vocab.strings[tag]
|
||||
token.lemma = self.vocab.strings[lemma]
|
||||
if ent_type == 'O':
|
||||
|
@ -511,31 +511,31 @@ cdef class Doc:
|
|||
# as it modifies the character offsets in the doc
|
||||
token.lex = lex
|
||||
for i in range(self.length):
|
||||
self.data[i].head += i
|
||||
self.c[i].head += i
|
||||
# Set the head of the merged token, and its dep relation, from the Span
|
||||
token.head = self.data[span_root].head
|
||||
token.head = self.c[span_root].head
|
||||
# Adjust deps before shrinking tokens
|
||||
# Tokens which point into the merged token should now point to it
|
||||
# Subtract the offset from all tokens which point to >= end
|
||||
offset = (end - start) - 1
|
||||
for i in range(self.length):
|
||||
head_idx = self.data[i].head
|
||||
head_idx = self.c[i].head
|
||||
if start <= head_idx < end:
|
||||
self.data[i].head = start
|
||||
self.c[i].head = start
|
||||
elif head_idx >= end:
|
||||
self.data[i].head -= offset
|
||||
self.c[i].head -= offset
|
||||
# Now compress the token array
|
||||
for i in range(end, self.length):
|
||||
self.data[i - offset] = self.data[i]
|
||||
self.c[i - offset] = self.c[i]
|
||||
for i in range(self.length - offset, self.length):
|
||||
memset(&self.data[i], 0, sizeof(TokenC))
|
||||
self.data[i].lex = &EMPTY_LEXEME
|
||||
memset(&self.c[i], 0, sizeof(TokenC))
|
||||
self.c[i].lex = &EMPTY_LEXEME
|
||||
self.length -= offset
|
||||
for i in range(self.length):
|
||||
# ...And, set heads back to a relative position
|
||||
self.data[i].head -= i
|
||||
self.c[i].head -= i
|
||||
# Set the left/right children, left/right edges
|
||||
set_children_from_heads(self.data, self.length)
|
||||
set_children_from_heads(self.c, self.length)
|
||||
# Clear the cached Python objects
|
||||
self._py_tokens = [None] * self.length
|
||||
# Return the merged Python object
|
||||
|
@ -569,3 +569,9 @@ cdef int set_children_from_heads(TokenC* tokens, int length) except -1:
|
|||
if child.r_edge > head.r_edge:
|
||||
head.r_edge = child.r_edge
|
||||
head.r_kids += 1
|
||||
|
||||
# Set sentence starts
|
||||
for i in range(length):
|
||||
if tokens[i].head == 0 and tokens[i].dep != 0:
|
||||
tokens[tokens[i].l_edge].sent_start = True
|
||||
|
||||
|
|
|
@ -177,12 +177,12 @@ cdef class Span:
|
|||
def __get__(self):
|
||||
# This should probably be called 'head', and the other one called
|
||||
# 'gov'. But we went with 'head' elsehwhere, and now we're stuck =/
|
||||
cdef const TokenC* start = &self.doc.data[self.start]
|
||||
cdef const TokenC* end = &self.doc.data[self.end]
|
||||
cdef const TokenC* start = &self.doc.c[self.start]
|
||||
cdef const TokenC* end = &self.doc.c[self.end]
|
||||
head = start
|
||||
while start <= (head + head.head) < end and head.head != 0:
|
||||
head += head.head
|
||||
return self.doc[head - self.doc.data]
|
||||
return self.doc[head - self.doc.c]
|
||||
|
||||
property lefts:
|
||||
"""Tokens that are to the left of the Span, whose head is within the Span."""
|
|
@ -31,7 +31,7 @@ cdef class Token:
|
|||
def __cinit__(self, Vocab vocab, Doc doc, int offset):
|
||||
self.vocab = vocab
|
||||
self.doc = doc
|
||||
self.c = &self.doc.data[offset]
|
||||
self.c = &self.doc.c[offset]
|
||||
self.i = offset
|
||||
self.array_len = doc.length
|
||||
|
||||
|
@ -143,7 +143,7 @@ cdef class Token:
|
|||
def __get__(self):
|
||||
cdef int i
|
||||
for i in range(self.vocab.vectors_length):
|
||||
if self.c.lex.repvec[i] != 0:
|
||||
if self.c.lex.vector[i] != 0:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
@ -158,8 +158,8 @@ cdef class Token:
|
|||
"\npython -m spacy.en.download all\n"
|
||||
"to install the data."
|
||||
)
|
||||
repvec_view = <float[:length,]>self.c.lex.repvec
|
||||
return numpy.asarray(repvec_view)
|
||||
vector_view = <float[:length,]>self.c.lex.vector
|
||||
return numpy.asarray(vector_view)
|
||||
|
||||
property repvec:
|
||||
def __get__(self):
|
||||
|
@ -259,14 +259,11 @@ cdef class Token:
|
|||
def __get__(self):
|
||||
"""Get a list of conjoined words."""
|
||||
cdef Token word
|
||||
conjuncts = []
|
||||
if self.dep_ != 'conj':
|
||||
for word in self.rights:
|
||||
if word.dep_ == 'conj':
|
||||
yield word
|
||||
yield from word.conjuncts
|
||||
conjuncts.append(word)
|
||||
conjuncts.extend(word.conjuncts)
|
||||
|
||||
property ent_type:
|
||||
def __get__(self):
|
||||
|
|
|
@ -40,7 +40,7 @@ DEF MAX_VEC_SIZE = 100000
|
|||
cdef float[MAX_VEC_SIZE] EMPTY_VEC
|
||||
memset(EMPTY_VEC, 0, sizeof(EMPTY_VEC))
|
||||
memset(&EMPTY_LEXEME, 0, sizeof(LexemeC))
|
||||
EMPTY_LEXEME.repvec = EMPTY_VEC
|
||||
EMPTY_LEXEME.vector = EMPTY_VEC
|
||||
|
||||
|
||||
cdef class Vocab:
|
||||
|
@ -62,9 +62,11 @@ cdef class Vocab:
|
|||
cdef Vocab self = cls(get_lex_attr=get_lex_attr, tag_map=tag_map,
|
||||
lemmatizer=lemmatizer, serializer_freqs=serializer_freqs)
|
||||
|
||||
if path.exists(path.join(data_dir, 'strings.json')):
|
||||
with io.open(path.join(data_dir, 'strings.json'), 'r', encoding='utf8') as file_:
|
||||
self.strings.load(file_)
|
||||
self.load_lexemes(path.join(data_dir, 'lexemes.bin'))
|
||||
|
||||
if path.exists(path.join(data_dir, 'vec.bin')):
|
||||
self.vectors_length = self.load_vectors_from_bin_loc(path.join(data_dir, 'vec.bin'))
|
||||
return self
|
||||
|
@ -160,7 +162,7 @@ cdef class Vocab:
|
|||
lex.orth = self.strings[string]
|
||||
lex.length = len(string)
|
||||
lex.id = self.length
|
||||
lex.repvec = <float*>mem.alloc(self.vectors_length, sizeof(float))
|
||||
lex.vector = <float*>mem.alloc(self.vectors_length, sizeof(float))
|
||||
if self.get_lex_attr is not None:
|
||||
for attr, func in self.get_lex_attr.items():
|
||||
value = func(string)
|
||||
|
@ -285,7 +287,7 @@ cdef class Vocab:
|
|||
fp.read_into(&lexeme.sentiment, 1, sizeof(lexeme.sentiment))
|
||||
fp.read_into(&lexeme.l2_norm, 1, sizeof(lexeme.l2_norm))
|
||||
|
||||
lexeme.repvec = EMPTY_VEC
|
||||
lexeme.vector = EMPTY_VEC
|
||||
py_str = self.strings[lexeme.orth]
|
||||
key = hash_string(py_str)
|
||||
self._by_hash.set(key, lexeme)
|
||||
|
@ -304,7 +306,7 @@ cdef class Vocab:
|
|||
cdef CFile out_file = CFile(out_loc, 'wb')
|
||||
for lexeme in self:
|
||||
word_str = lexeme.orth_.encode('utf8')
|
||||
vec = lexeme.c.repvec
|
||||
vec = lexeme.c.vector
|
||||
word_len = len(word_str)
|
||||
|
||||
out_file.write_from(&word_len, 1, sizeof(word_len))
|
||||
|
@ -329,10 +331,10 @@ cdef class Vocab:
|
|||
vec_len, len(pieces))
|
||||
orth = self.strings[word_str]
|
||||
lexeme = <LexemeC*><void*>self.get_by_orth(self.mem, orth)
|
||||
lexeme.repvec = <float*>self.mem.alloc(self.vectors_length, sizeof(float))
|
||||
lexeme.vector = <float*>self.mem.alloc(self.vectors_length, sizeof(float))
|
||||
|
||||
for i, val_str in enumerate(pieces):
|
||||
lexeme.repvec[i] = float(val_str)
|
||||
lexeme.vector[i] = float(val_str)
|
||||
return vec_len
|
||||
|
||||
def load_vectors_from_bin_loc(self, loc):
|
||||
|
@ -374,12 +376,12 @@ cdef class Vocab:
|
|||
for orth, lex_addr in self._by_orth.items():
|
||||
lex = <LexemeC*>lex_addr
|
||||
if lex.lower < vectors.size():
|
||||
lex.repvec = vectors[lex.lower]
|
||||
lex.vector = vectors[lex.lower]
|
||||
for i in range(vec_len):
|
||||
lex.l2_norm += (lex.repvec[i] * lex.repvec[i])
|
||||
lex.l2_norm += (lex.vector[i] * lex.vector[i])
|
||||
lex.l2_norm = math.sqrt(lex.l2_norm)
|
||||
else:
|
||||
lex.repvec = EMPTY_VEC
|
||||
lex.vector = EMPTY_VEC
|
||||
return vec_len
|
||||
|
||||
|
||||
|
|
|
@ -80,10 +80,10 @@ include ./meta.jade
|
|||
|
|
||||
| def match_tweet(spacy, text, query):
|
||||
| def get_vector(word):
|
||||
| return spacy.vocab[word].repvec
|
||||
| return spacy.vocab[word].vector
|
||||
|
|
||||
| tweet = spacy(text)
|
||||
| tweet = [w.repvec for w in tweet if w.is_alpha and w.lower_ != query]
|
||||
| tweet = [w.vector for w in tweet if w.is_alpha and w.lower_ != query]
|
||||
| if tweet:
|
||||
| accept = map(get_vector, 'Jeb Cheney Republican 9/11h'.split())
|
||||
| reject = map(get_vector, 'garden Reggie hairy'.split())
|
||||
|
@ -147,11 +147,11 @@ include ./meta.jade
|
|||
pre.language-python: code
|
||||
| def handle_tweet(spacy, resp, query):
|
||||
| def get_vector(word):
|
||||
| return spacy.vocab[word].repvec
|
||||
| return spacy.vocab[word].vector
|
||||
|
|
||||
| text = resp.get('text', '').decode('utf8')
|
||||
| tweet = spacy(text)
|
||||
| tweet = [w.repvec for w in tweet if w.is_alpha and w.lower_ != query]
|
||||
| tweet = [w.vector for w in tweet if w.is_alpha and w.lower_ != query]
|
||||
| if tweet:
|
||||
| accept = map(get_vector, 'Jeb Cheney Republican 9/11h'.split())
|
||||
| reject = map(get_vector, 'garden Reggie hairy'.split())
|
||||
|
|
Loading…
Reference in New Issue
Block a user