mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-10 19:57:17 +03:00
* Work on word vectors, and other stuff
This commit is contained in:
parent
7e69e17161
commit
6c7e44140b
|
@ -1,5 +1,6 @@
|
|||
from __future__ import unicode_literals
|
||||
from os import path
|
||||
import re
|
||||
|
||||
from .. import orth
|
||||
from ..vocab import Vocab
|
||||
|
@ -11,6 +12,9 @@ from .pos import POS_TAGS
|
|||
from .attrs import get_flags
|
||||
|
||||
|
||||
from ..util import read_lang_data
|
||||
|
||||
|
||||
def get_lex_props(string):
|
||||
return {
|
||||
'flags': get_flags(string),
|
||||
|
@ -64,11 +68,16 @@ class English(object):
|
|||
tag_names = list(POS_TAGS.keys())
|
||||
tag_names.sort()
|
||||
if data_dir is None:
|
||||
self.tokenizer = Tokenizer(self.vocab, {}, None, None, None,
|
||||
POS_TAGS, tag_names)
|
||||
tok_rules = {}
|
||||
prefix_re = None
|
||||
suffix_re = None
|
||||
infix_re = None
|
||||
else:
|
||||
self.tokenizer = Tokenizer.from_dir(self.vocab, path.join(data_dir, 'tokenizer'),
|
||||
POS_TAGS, tag_names)
|
||||
tok_data_dir = path.join(data_dir, 'tokenizer')
|
||||
tok_rules, prefix_re, suffix_re, infix_re = read_lang_data(tok_data_dir)
|
||||
self.tokenizer = Tokenizer(self.vocab, tok_rules, re.compile(prefix_re),
|
||||
re.compile(suffix_re), re.compile(infix_re),
|
||||
POS_TAGS, tag_names)
|
||||
self.strings = self.vocab.strings
|
||||
self._tagger = None
|
||||
self._parser = None
|
||||
|
@ -100,11 +109,11 @@ class English(object):
|
|||
Returns:
|
||||
tokens (spacy.tokens.Tokens):
|
||||
"""
|
||||
tokens = self.tokenizer.tokenize(text)
|
||||
tokens = self.tokenizer(text)
|
||||
if tag:
|
||||
self.tagger(tokens)
|
||||
if parse:
|
||||
self.parser.parse(tokens)
|
||||
self.parser(tokens)
|
||||
return tokens
|
||||
|
||||
@property
|
||||
|
|
|
@ -1,18 +1,16 @@
|
|||
from ..typedefs cimport FLAG0, FLAG1, FLAG2, FLAG3, FLAG4, FLAG5, FLAG6, FLAG7
|
||||
from ..typedefs cimport FLAG8, FLAG9
|
||||
from ..typedefs cimport ID as _ID
|
||||
from ..typedefs cimport SIC as _SIC
|
||||
from ..typedefs cimport SHAPE as _SHAPE
|
||||
from ..typedefs cimport NORM1 as _NORM1
|
||||
from ..typedefs cimport NORM2 as _NORM2
|
||||
from ..typedefs cimport CLUSTER as _CLUSTER
|
||||
from ..typedefs cimport PREFIX as _PREFIX
|
||||
from ..typedefs cimport SUFFIX as _SUFFIX
|
||||
from ..typedefs cimport LEMMA as _LEMMA
|
||||
from ..typedefs cimport POS as _POS
|
||||
from ..attrs cimport FLAG0, FLAG1, FLAG2, FLAG3, FLAG4, FLAG5, FLAG6, FLAG7
|
||||
from ..attrs cimport FLAG8, FLAG9
|
||||
from ..attrs cimport SIC as _SIC
|
||||
from ..attrs cimport SHAPE as _SHAPE
|
||||
from ..attrs cimport NORM1 as _NORM1
|
||||
from ..attrs cimport NORM2 as _NORM2
|
||||
from ..attrs cimport CLUSTER as _CLUSTER
|
||||
from ..attrs cimport PREFIX as _PREFIX
|
||||
from ..attrs cimport SUFFIX as _SUFFIX
|
||||
from ..attrs cimport LEMMA as _LEMMA
|
||||
from ..attrs cimport POS as _POS
|
||||
|
||||
|
||||
# Work around the lack of global cpdef variables
|
||||
cpdef enum:
|
||||
IS_ALPHA = FLAG0
|
||||
IS_ASCII = FLAG1
|
||||
|
@ -25,7 +23,6 @@ cpdef enum:
|
|||
LIKE_URL = FLAG8
|
||||
LIKE_NUM = FLAG9
|
||||
|
||||
ID = _ID
|
||||
SIC = _SIC
|
||||
SHAPE = _SHAPE
|
||||
NORM1 = _NORM1
|
||||
|
|
|
@ -4,33 +4,52 @@ import tarfile
|
|||
import shutil
|
||||
import requests
|
||||
|
||||
PARSER_URL = 'https://s3-us-west-1.amazonaws.com/media.spacynlp.com/en.tgz'
|
||||
PARSER_URL = 'http://s3-us-west-1.amazonaws.com/media.spacynlp.com/en.tgz'
|
||||
|
||||
DEST_DIR = path.join(path.dirname(__file__), 'data', 'deps')
|
||||
DEP_VECTORS_URL = 'http://u.cs.biu.ac.il/~yogo/data/syntemb/deps.words.bz2'
|
||||
|
||||
DEST_DIR = path.join(path.dirname(__file__), 'data')
|
||||
|
||||
def download_file(url):
|
||||
local_filename = url.split('/')[-1]
|
||||
return path.join(DEST_DIR, local_filename)
|
||||
# NOTE the stream=True parameter
|
||||
r = requests.get(url, stream=True)
|
||||
print "Download %s" % url
|
||||
i = 0
|
||||
with open(local_filename, 'wb') as f:
|
||||
for chunk in r.iter_content(chunk_size=1024):
|
||||
if chunk: # filter out keep-alive new chunks
|
||||
f.write(chunk)
|
||||
f.flush()
|
||||
print i
|
||||
i += 1
|
||||
return local_filename
|
||||
|
||||
def main():
|
||||
if not os.path.exists(DEST_DIR):
|
||||
os.mkdir(DEST_DIR)
|
||||
assert not path.exists(path.join(DEST_DIR, 'en'))
|
||||
def install_parser_model(url, dest_dir):
|
||||
if not os.path.exists(dest_dir):
|
||||
os.mkdir(dest_dir)
|
||||
assert not path.exists(path.join(dest_dir, 'en'))
|
||||
|
||||
|
||||
filename = download_file(URL)
|
||||
filename = download_file(url)
|
||||
t = tarfile.open(filename, mode=":gz")
|
||||
t.extractall(DEST_DIR)
|
||||
shutil.move(path.join(DEST_DIR, 'en', 'deps', 'model'), DEST_DIR)
|
||||
shutil.move(path.join(DEST_DIR, 'en', 'deps', 'config.json'), DEST_DIR)
|
||||
shutil.rmtree(path.join(DEST_DIR, 'en'))
|
||||
t.extractall(dest_dir)
|
||||
shutil.move(path.join(dest_dir, 'en', 'deps', 'model'), dest_dir)
|
||||
shutil.move(path.join(dest_dir, 'en', 'deps', 'config.json'), dest_dir)
|
||||
shutil.rmtree(path.join(dest_dir, 'en'))
|
||||
|
||||
|
||||
def install_dep_vectors(url, dest_dir):
|
||||
if not os.path.exists(dest_dir):
|
||||
os.mkdir(dest_dir)
|
||||
|
||||
filename = download_file(url)
|
||||
shutil.move(filename, path.join(dest_dir, 'vec.bz2'))
|
||||
|
||||
|
||||
def main():
|
||||
#install_parser_model(PARSER_URL, path.join(DEST_DIR, 'deps'))
|
||||
install_dep_vectors(DEP_VECTORS_URL, path.join(DEST_DIR, 'vocab'))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -247,11 +247,12 @@ cdef class EnPosTagger:
|
|||
cdef atom_t[N_CONTEXT_FIELDS] context
|
||||
cdef const weight_t* scores
|
||||
for i in range(tokens.length):
|
||||
if tokens.data[i].fine_pos == 0:
|
||||
if tokens.data[i].pos == 0:
|
||||
fill_context(context, i, tokens.data)
|
||||
scores = self.model.score(context)
|
||||
tokens.data[i].fine_pos = arg_max(scores, self.model.n_classes)
|
||||
tokens.data[i].tag = arg_max(scores, self.model.n_classes)
|
||||
self.set_morph(i, tokens.data)
|
||||
tokens.pos_scheme = self.tag_map
|
||||
|
||||
def train(self, Tokens tokens, object golds):
|
||||
cdef int i
|
||||
|
@ -263,13 +264,13 @@ cdef class EnPosTagger:
|
|||
scores = self.model.score(context)
|
||||
guess = arg_max(scores, self.model.n_classes)
|
||||
self.model.update(context, guess, golds[i], guess != golds[i])
|
||||
tokens.data[i].fine_pos = guess
|
||||
tokens.data[i].tag = guess
|
||||
self.set_morph(i, tokens.data)
|
||||
correct += guess == golds[i]
|
||||
return correct
|
||||
|
||||
cdef int set_morph(self, const int i, TokenC* tokens) except -1:
|
||||
cdef const PosTag* tag = &self.tags[tokens[i].fine_pos]
|
||||
cdef const PosTag* tag = &self.tags[tokens[i].tag]
|
||||
tokens[i].pos = tag.pos
|
||||
cached = <_CachedMorph*>self._morph_cache.get(tag.id, tokens[i].lex.sic)
|
||||
if cached is NULL:
|
||||
|
|
|
@ -3,39 +3,70 @@ from .typedefs cimport ID, SIC, NORM1, NORM2, SHAPE, PREFIX, SUFFIX, LENGTH, CLU
|
|||
from .structs cimport LexemeC
|
||||
from .strings cimport StringStore
|
||||
|
||||
from numpy cimport ndarray
|
||||
|
||||
|
||||
|
||||
cdef LexemeC EMPTY_LEXEME
|
||||
|
||||
|
||||
cdef int set_lex_struct_props(LexemeC* lex, dict props, StringStore strings) except -1
|
||||
cdef int set_lex_struct_props(LexemeC* lex, dict props, StringStore strings,
|
||||
const float* empty_vec) except -1
|
||||
|
||||
cdef class Lexeme:
|
||||
cdef const float* vec
|
||||
cdef readonly ndarray vec
|
||||
|
||||
cdef readonly flags_t flags
|
||||
cdef readonly attr_t id
|
||||
cdef readonly attr_t length
|
||||
|
||||
cdef readonly attr_t sic
|
||||
cdef readonly unicode norm1
|
||||
cdef readonly unicode norm2
|
||||
cdef readonly unicode shape
|
||||
cdef readonly unicode prefix
|
||||
cdef readonly unicode suffix
|
||||
cdef readonly attr_t norm1
|
||||
cdef readonly attr_t norm2
|
||||
cdef readonly attr_t shape
|
||||
cdef readonly attr_t prefix
|
||||
cdef readonly attr_t suffix
|
||||
|
||||
cdef readonly attr_t sic_id
|
||||
cdef readonly attr_t norm1_id
|
||||
cdef readonly attr_t norm2_id
|
||||
cdef readonly attr_t shape_id
|
||||
cdef readonly attr_t prefix_id
|
||||
cdef readonly attr_t suffix_id
|
||||
cdef readonly unicode sic_
|
||||
cdef readonly unicode norm1_
|
||||
cdef readonly unicode norm2_
|
||||
cdef readonly unicode shape_
|
||||
cdef readonly unicode prefix_
|
||||
cdef readonly unicode suffix_
|
||||
|
||||
cdef readonly attr_t cluster
|
||||
cdef readonly float prob
|
||||
cdef readonly float sentiment
|
||||
|
||||
# Workaround for an apparent bug in the way the decorator is handled ---
|
||||
# TODO: post bug report / patch to Cython.
|
||||
@staticmethod
|
||||
cdef inline Lexeme from_ptr(const LexemeC* ptr, StringStore strings):
|
||||
cdef Lexeme py = Lexeme.__new__(Lexeme, 300)
|
||||
for i in range(300):
|
||||
py.vec[i] = ptr.vec[i]
|
||||
py.flags = ptr.flags
|
||||
py.id = ptr.id
|
||||
py.length = ptr.length
|
||||
|
||||
cdef Lexeme Lexeme_cinit(const LexemeC* c, StringStore strings)
|
||||
py.sic = ptr.sic
|
||||
py.norm1 = ptr.norm1
|
||||
py.norm2 = ptr.norm2
|
||||
py.shape = ptr.shape
|
||||
py.prefix = ptr.prefix
|
||||
py.suffix = ptr.suffix
|
||||
|
||||
py.sic_ = strings[ptr.sic]
|
||||
py.norm1_ = strings[ptr.norm1]
|
||||
py.norm2_ = strings[ptr.norm2]
|
||||
py.shape_ = strings[ptr.shape]
|
||||
py.prefix_ = strings[ptr.prefix]
|
||||
py.suffix_ = strings[ptr.suffix]
|
||||
|
||||
py.cluster = ptr.cluster
|
||||
py.prob = ptr.prob
|
||||
py.sentiment = ptr.sentiment
|
||||
return py
|
||||
|
||||
|
||||
cdef inline bint check_flag(const LexemeC* lexeme, attr_id_t flag_id) nogil:
|
||||
|
|
|
@ -7,13 +7,14 @@ from libc.string cimport memset
|
|||
|
||||
from .orth cimport word_shape
|
||||
from .typedefs cimport attr_t
|
||||
import numpy
|
||||
|
||||
|
||||
memset(&EMPTY_LEXEME, 0, sizeof(LexemeC))
|
||||
|
||||
|
||||
cdef int set_lex_struct_props(LexemeC* lex, dict props, StringStore string_store) except -1:
|
||||
|
||||
cdef int set_lex_struct_props(LexemeC* lex, dict props, StringStore string_store,
|
||||
const float* empty_vec) except -1:
|
||||
lex.length = props['length']
|
||||
lex.sic = string_store[props['sic']]
|
||||
lex.norm1 = string_store[props['norm1']]
|
||||
|
@ -27,39 +28,10 @@ cdef int set_lex_struct_props(LexemeC* lex, dict props, StringStore string_store
|
|||
lex.sentiment = props['sentiment']
|
||||
|
||||
lex.flags = props['flags']
|
||||
lex.vec = empty_vec
|
||||
|
||||
|
||||
cdef class Lexeme:
|
||||
"""A dummy docstring"""
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
cdef Lexeme Lexeme_cinit(const LexemeC* c, StringStore strings):
|
||||
cdef Lexeme py = Lexeme.__new__(Lexeme)
|
||||
|
||||
py.vec = c.vec
|
||||
|
||||
py.flags = c.flags
|
||||
py.id = c.id
|
||||
py.length = c.length
|
||||
|
||||
py.sic = c.sic
|
||||
py.norm1 = strings[c.norm1]
|
||||
py.norm2 = strings[c.norm2]
|
||||
py.shape = strings[c.shape]
|
||||
py.prefix = strings[c.prefix]
|
||||
py.suffix = strings[c.suffix]
|
||||
|
||||
py.sic_id = c.sic
|
||||
py.norm1_id = c.norm1
|
||||
py.norm2_id = c.norm2
|
||||
py.shape_id = c.shape
|
||||
py.prefix_id = c.prefix
|
||||
py.suffix_id = c.suffix
|
||||
|
||||
py.cluster = c.cluster
|
||||
|
||||
py.prob = c.prob
|
||||
py.sentiment = c.sentiment
|
||||
return py
|
||||
def __cinit__(self, int vec_size):
|
||||
self.vec = numpy.ndarray(shape=(vec_size,), dtype=numpy.float32)
|
||||
|
|
|
@ -137,8 +137,42 @@ cpdef unicode word_shape(unicode string):
|
|||
return ''.join(shape)
|
||||
|
||||
|
||||
cpdef unicode norm1(unicode string, lower_pc=0.0, upper_pc=0.0, title_pc=0.0):
|
||||
"""Apply level 1 normalization:
|
||||
|
||||
* Case is canonicalized, using frequency statistics
|
||||
* Unicode mapped to ascii, via unidecode
|
||||
* Regional spelling variations are normalized
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
cpdef bytes asciied(unicode string):
|
||||
cdef str stripped = unidecode(string)
|
||||
if not stripped:
|
||||
return b'???'
|
||||
return stripped.encode('ascii')
|
||||
|
||||
|
||||
# Exceptions --- do not convert these
|
||||
_uk_us_except = set([
|
||||
'our',
|
||||
'ours',
|
||||
'four',
|
||||
'fours',
|
||||
'your',
|
||||
'yours',
|
||||
'hour',
|
||||
'hours',
|
||||
'course',
|
||||
'rise',
|
||||
])
|
||||
def uk_to_usa(unicode string):
|
||||
if not string.islower():
|
||||
return string
|
||||
if string in _uk_us_except:
|
||||
return string
|
||||
our = re.compile(r'ours?$')
|
||||
string = our.sub('or', string)
|
||||
|
||||
return string
|
||||
|
|
|
@ -44,12 +44,12 @@ cdef struct TokenC:
|
|||
const LexemeC* lex
|
||||
Morphology morph
|
||||
univ_tag_t pos
|
||||
int fine_pos
|
||||
int tag
|
||||
int idx
|
||||
int lemma
|
||||
int sense
|
||||
int head
|
||||
int dep_tag
|
||||
int dep
|
||||
uint32_t l_kids
|
||||
uint32_t r_kids
|
||||
|
||||
|
|
|
@ -28,7 +28,7 @@ cdef inline void fill_token(atom_t* context, const TokenC* token) nogil:
|
|||
else:
|
||||
context[0] = token.lex.sic
|
||||
context[1] = token.lemma
|
||||
context[2] = token.fine_pos
|
||||
context[2] = token.tag
|
||||
context[3] = token.lex.cluster
|
||||
# We've read in the string little-endian, so now we can take & (2**n)-1
|
||||
# to get the first n bits of the cluster.
|
||||
|
@ -44,7 +44,7 @@ cdef inline void fill_token(atom_t* context, const TokenC* token) nogil:
|
|||
# the source that are set to 1.
|
||||
context[4] = token.lex.cluster & 63
|
||||
context[5] = token.lex.cluster & 15
|
||||
context[6] = token.dep_tag if has_head(token) else 0
|
||||
context[6] = token.dep if has_head(token) else 0
|
||||
|
||||
|
||||
cdef int fill_context(atom_t* context, State* state) except -1:
|
||||
|
|
|
@ -12,7 +12,7 @@ DEF NON_MONOTONIC = True
|
|||
cdef int add_dep(State *s, int head, int child, int label) except -1:
|
||||
cdef int dist = head - child
|
||||
s.sent[child].head = dist
|
||||
s.sent[child].dep_tag = label
|
||||
s.sent[child].dep = label
|
||||
# Keep a bit-vector tracking child dependencies. If a word has a child at
|
||||
# offset i from it, set that bit (tracking left and right separately)
|
||||
if child > head:
|
||||
|
@ -38,7 +38,7 @@ cdef int push_stack(State *s) except -1:
|
|||
if at_eol(s):
|
||||
while s.stack_len != 0:
|
||||
if not has_head(get_s0(s)):
|
||||
get_s0(s).dep_tag = 0
|
||||
get_s0(s).dep = 0
|
||||
pop_stack(s)
|
||||
|
||||
|
||||
|
|
|
@ -123,7 +123,7 @@ cdef class TransitionSystem:
|
|||
if t.move == SHIFT:
|
||||
# Set the dep label, in case we need it after we reduce
|
||||
if NON_MONOTONIC:
|
||||
get_s0(s).dep_tag = t.label
|
||||
get_s0(s).dep = t.label
|
||||
push_stack(s)
|
||||
elif t.move == LEFT:
|
||||
add_dep(s, s.i, s.stack[0], t.label)
|
||||
|
@ -132,7 +132,7 @@ cdef class TransitionSystem:
|
|||
add_dep(s, s.stack[0], s.i, t.label)
|
||||
push_stack(s)
|
||||
elif t.move == REDUCE:
|
||||
add_dep(s, s.stack[-1], s.stack[0], get_s0(s).dep_tag)
|
||||
add_dep(s, s.stack[-1], s.stack[0], get_s0(s).dep)
|
||||
pop_stack(s)
|
||||
else:
|
||||
raise Exception(t.move)
|
||||
|
|
|
@ -9,5 +9,3 @@ cdef class GreedyParser:
|
|||
cdef object cfg
|
||||
cdef readonly Model model
|
||||
cdef TransitionSystem moves
|
||||
|
||||
cpdef int parse(self, Tokens tokens) except -1
|
||||
|
|
|
@ -65,7 +65,7 @@ cdef class GreedyParser:
|
|||
hasty_templ, full_templ = get_templates(self.cfg.features)
|
||||
self.model = Model(self.moves.n_moves, full_templ, model_dir)
|
||||
|
||||
cpdef int parse(self, Tokens tokens) except -1:
|
||||
def __call__(self, Tokens tokens):
|
||||
cdef:
|
||||
Transition guess
|
||||
uint64_t state_key
|
||||
|
|
|
@ -28,8 +28,6 @@ cdef class Tokenizer:
|
|||
cdef object _infix_re
|
||||
|
||||
cpdef Tokens tokens_from_list(self, list strings)
|
||||
cpdef Tokens tokenize(self, unicode text)
|
||||
|
||||
|
||||
cdef int _try_cache(self, int idx, hash_t key, Tokens tokens) except -1
|
||||
cdef int _tokenize(self, Tokens tokens, UniStr* span, int start, int end) except -1
|
||||
|
|
|
@ -31,18 +31,6 @@ cdef class Tokenizer:
|
|||
self.vocab = vocab
|
||||
self._load_special_tokenization(rules, pos_tags, tag_names)
|
||||
|
||||
@classmethod
|
||||
def from_dir(cls, Vocab vocab, object data_dir, object pos_tags, object tag_names):
|
||||
if not path.exists(data_dir):
|
||||
raise IOError("Directory %s not found -- cannot load Tokenizer." % data_dir)
|
||||
if not path.isdir(data_dir):
|
||||
raise IOError("Path %s is a file, not a dir -- cannot load Tokenizer." % data_dir)
|
||||
|
||||
assert path.exists(data_dir) and path.isdir(data_dir)
|
||||
rules, prefix_re, suffix_re, infix_re = util.read_lang_data(data_dir)
|
||||
return cls(vocab, rules, re.compile(prefix_re), re.compile(suffix_re),
|
||||
re.compile(infix_re), pos_tags, tag_names)
|
||||
|
||||
cpdef Tokens tokens_from_list(self, list strings):
|
||||
cdef int length = sum([len(s) for s in strings])
|
||||
cdef Tokens tokens = Tokens(self.vocab, length)
|
||||
|
@ -57,7 +45,7 @@ cdef class Tokenizer:
|
|||
idx += len(py_string) + 1
|
||||
return tokens
|
||||
|
||||
cpdef Tokens tokenize(self, unicode string):
|
||||
def __call__(self, unicode string):
|
||||
"""Tokenize a string.
|
||||
|
||||
The tokenization rules are defined in three places:
|
||||
|
@ -257,7 +245,7 @@ cdef class Tokenizer:
|
|||
tokens[i].lemma = self.vocab.strings[lemma]
|
||||
if 'pos' in props:
|
||||
# TODO: Clean up this mess...
|
||||
tokens[i].fine_pos = tag_names.index(props['pos'])
|
||||
tokens[i].tag = tag_names.index(props['pos'])
|
||||
tokens[i].pos = tag_map[props['pos']][0]
|
||||
# These are defaults, which can be over-ridden by the
|
||||
# token-specific props.
|
||||
|
|
|
@ -198,9 +198,8 @@ cdef class Token:
|
|||
self.sentiment = t.lex.sentiment
|
||||
self.flags = t.lex.flags
|
||||
self.lemma = t.lemma
|
||||
self.pos = t.pos
|
||||
self.fine_pos = t.fine_pos
|
||||
self.dep_tag = t.dep_tag
|
||||
self.tag = t.tag
|
||||
self.dep = t.dep
|
||||
|
||||
def __unicode__(self):
|
||||
cdef const TokenC* t = &self._seq.data[self.i]
|
||||
|
@ -220,6 +219,12 @@ cdef class Token:
|
|||
"""
|
||||
return self._seq.data[self.i].lex.length
|
||||
|
||||
def check_flag(self, attr_id_t flag):
|
||||
return False
|
||||
|
||||
def is_pos(self, univ_tag_t pos):
|
||||
return False
|
||||
|
||||
property head:
|
||||
"""The token predicted by the parser to be the head of the current token."""
|
||||
def __get__(self):
|
||||
|
@ -267,16 +272,10 @@ cdef class Token:
|
|||
cdef unicode py_ustr = self._seq.vocab.strings[t.lemma]
|
||||
return py_ustr
|
||||
|
||||
property pos_:
|
||||
property tag_:
|
||||
def __get__(self):
|
||||
return self._seq.vocab.strings[self.pos]
|
||||
return self._seq.tag_names[self.tag]
|
||||
|
||||
property fine_pos_:
|
||||
property dep_:
|
||||
def __get__(self):
|
||||
return self._seq.vocab.strings[self.fine_pos]
|
||||
|
||||
property dep_tag_:
|
||||
def __get__(self):
|
||||
return self._seq.vocab.strings[self.dep_tag]
|
||||
|
||||
|
||||
return self._seq.dep_names[self.dep]
|
||||
|
|
140
spacy/vocab.pyx
140
spacy/vocab.pyx
|
@ -1,16 +1,20 @@
|
|||
from libc.stdio cimport fopen, fclose, fread, fwrite, FILE
|
||||
from libc.string cimport memset
|
||||
from libc.stdint cimport int32_t
|
||||
|
||||
import bz2
|
||||
from os import path
|
||||
import codecs
|
||||
|
||||
from .lexeme cimport EMPTY_LEXEME
|
||||
from .lexeme cimport set_lex_struct_props
|
||||
from .lexeme cimport Lexeme_cinit
|
||||
from .lexeme cimport Lexeme
|
||||
from .strings cimport slice_unicode
|
||||
from .strings cimport hash_string
|
||||
from .orth cimport word_shape
|
||||
|
||||
from cymem.cymem cimport Address
|
||||
|
||||
|
||||
DEF MAX_VEC_SIZE = 100000
|
||||
|
||||
|
@ -34,12 +38,15 @@ cdef class Vocab:
|
|||
if data_dir is not None:
|
||||
if not path.exists(data_dir):
|
||||
raise IOError("Directory %s not found -- cannot load Vocab." % data_dir)
|
||||
assert EMPTY_LEXEME.vec != NULL
|
||||
if data_dir is not None:
|
||||
if not path.isdir(data_dir):
|
||||
raise IOError("Path %s is a file, not a dir -- cannot load Vocab." % data_dir)
|
||||
self.strings.load(path.join(data_dir, 'strings.txt'))
|
||||
self.load_lexemes(path.join(data_dir, 'lexemes.bin'))
|
||||
#self.load_vectors(path.join(data_dir, 'deps.words'))
|
||||
self.load_vectors(path.join(data_dir, 'vec.bin'))
|
||||
for i in range(self.lexemes.size()):
|
||||
assert self.lexemes[i].vec != NULL, repr(self.strings[self.lexemes[i].sic])
|
||||
|
||||
def __len__(self):
|
||||
"""The current number of lexemes stored."""
|
||||
|
@ -52,13 +59,15 @@ cdef class Vocab:
|
|||
cdef LexemeC* lex
|
||||
lex = <LexemeC*>self._map.get(c_str.key)
|
||||
if lex != NULL:
|
||||
assert lex.vec != NULL
|
||||
return lex
|
||||
if c_str.n < 3:
|
||||
mem = self.mem
|
||||
cdef unicode py_str = c_str.chars[:c_str.n]
|
||||
lex = <LexemeC*>mem.alloc(sizeof(LexemeC), 1)
|
||||
props = self.lexeme_props_getter(py_str)
|
||||
set_lex_struct_props(lex, props, self.strings)
|
||||
set_lex_struct_props(lex, props, self.strings, EMPTY_VEC)
|
||||
assert lex.vec != NULL
|
||||
if mem is self.mem:
|
||||
lex.id = self.lexemes.size()
|
||||
self._add_lex_to_vocab(c_str.key, lex)
|
||||
|
@ -98,7 +107,7 @@ cdef class Vocab:
|
|||
lexeme = self.get(self.mem, &c_str)
|
||||
else:
|
||||
raise ValueError("Vocab unable to map type: %s. Maps unicode --> int or int --> unicode" % str(type(id_or_string)))
|
||||
return Lexeme_cinit(lexeme, self.strings)
|
||||
return Lexeme.from_ptr(lexeme, self.strings)
|
||||
|
||||
def __setitem__(self, unicode py_str, dict props):
|
||||
cdef UniStr c_str
|
||||
|
@ -109,7 +118,8 @@ cdef class Vocab:
|
|||
lex = <LexemeC*>self.mem.alloc(sizeof(LexemeC), 1)
|
||||
lex.id = self.lexemes.size()
|
||||
self._add_lex_to_vocab(c_str.key, lex)
|
||||
set_lex_struct_props(lex, props, self.strings)
|
||||
set_lex_struct_props(lex, props, self.strings, EMPTY_VEC)
|
||||
assert lex.vec != NULL
|
||||
assert lex.sic < 1000000
|
||||
|
||||
def dump(self, loc):
|
||||
|
@ -147,8 +157,9 @@ cdef class Vocab:
|
|||
if st != 1:
|
||||
break
|
||||
lexeme = <LexemeC*>self.mem.alloc(sizeof(LexemeC), 1)
|
||||
lexeme.vec = EMPTY_VEC
|
||||
# Copies data from the file into the lexeme
|
||||
st = fread(lexeme, sizeof(LexemeC), 1, fp)
|
||||
lexeme.vec = EMPTY_VEC
|
||||
if st != 1:
|
||||
break
|
||||
self._map.set(key, lexeme)
|
||||
|
@ -157,29 +168,98 @@ cdef class Vocab:
|
|||
self.lexemes[lexeme.id] = lexeme
|
||||
i += 1
|
||||
fclose(fp)
|
||||
|
||||
|
||||
def load_vectors(self, loc):
|
||||
cdef int i
|
||||
cdef unicode line
|
||||
cdef unicode word
|
||||
cdef unicode val_str
|
||||
cdef hash_t key
|
||||
cdef LexemeC* lex
|
||||
file_ = _CFile(loc, 'rb')
|
||||
cdef int32_t word_len
|
||||
cdef int32_t vec_len
|
||||
cdef float* vec
|
||||
|
||||
with codecs.open(loc, 'r', 'utf8') as file_:
|
||||
for line in file_:
|
||||
pieces = line.split()
|
||||
word = pieces.pop(0)
|
||||
if len(pieces) >= MAX_VEC_SIZE:
|
||||
sizes = (len(pieces), MAX_VEC_SIZE)
|
||||
msg = ("Your vector is %d elements."
|
||||
"The compile-time limit is %d elements." % sizes)
|
||||
raise ValueError(msg)
|
||||
key = hash_string(word)
|
||||
lex = <LexemeC*>self._map.get(key)
|
||||
if lex is not NULL:
|
||||
vec = <float*>self.mem.alloc(len(pieces), sizeof(float))
|
||||
for i, val_str in enumerate(pieces):
|
||||
vec[i] = float(val_str)
|
||||
lex.vec = vec
|
||||
cdef Address mem
|
||||
cdef id_t string_id
|
||||
cdef bytes py_word
|
||||
cdef vector[float*] vectors
|
||||
cdef int i
|
||||
while True:
|
||||
try:
|
||||
file_.read(&word_len, sizeof(word_len), 1)
|
||||
except IOError:
|
||||
break
|
||||
file_.read(&vec_len, sizeof(vec_len), 1)
|
||||
|
||||
mem = Address(word_len, sizeof(char))
|
||||
chars = <char*>mem.ptr
|
||||
vec = <float*>self.mem.alloc(vec_len, sizeof(float))
|
||||
|
||||
file_.read(chars, sizeof(char), word_len)
|
||||
file_.read(vec, sizeof(float), vec_len)
|
||||
|
||||
string_id = self.strings[chars[:word_len]]
|
||||
while string_id >= vectors.size():
|
||||
vectors.push_back(EMPTY_VEC)
|
||||
assert vec != NULL
|
||||
vectors[string_id] = vec
|
||||
cdef LexemeC* lex
|
||||
for i in range(self.lexemes.size()):
|
||||
# Cast away the const, cos we can modify our lexemes
|
||||
lex = <LexemeC*>self.lexemes[i]
|
||||
if lex.sic < vectors.size():
|
||||
lex.vec = vectors[lex.sic]
|
||||
else:
|
||||
lex.vec = EMPTY_VEC
|
||||
assert lex.vec != NULL
|
||||
|
||||
|
||||
def write_binary_vectors(in_loc, out_loc):
|
||||
cdef _CFile out_file = _CFile(out_loc, 'wb')
|
||||
cdef Address mem
|
||||
cdef int32_t word_len
|
||||
cdef int32_t vec_len
|
||||
cdef char* chars
|
||||
with bz2.BZ2File(in_loc, 'r') as file_:
|
||||
for line in file_:
|
||||
pieces = line.split()
|
||||
word = pieces.pop(0)
|
||||
mem = Address(len(pieces), sizeof(float))
|
||||
vec = <float*>mem.ptr
|
||||
for i, val_str in enumerate(pieces):
|
||||
vec[i] = float(val_str)
|
||||
|
||||
word_len = len(word)
|
||||
vec_len = len(pieces)
|
||||
|
||||
out_file.write(sizeof(word_len), 1, &word_len)
|
||||
out_file.write(sizeof(vec_len), 1, &vec_len)
|
||||
|
||||
chars = <char*>word
|
||||
out_file.write(sizeof(char), len(word), chars)
|
||||
out_file.write(sizeof(float), vec_len, vec)
|
||||
|
||||
|
||||
cdef class _CFile:
|
||||
cdef FILE* fp
|
||||
def __init__(self, loc, mode):
|
||||
cdef bytes bytes_loc = loc.encode('utf8') if type(loc) == unicode else loc
|
||||
self.fp = fopen(<char*>bytes_loc, mode)
|
||||
if self.fp == NULL:
|
||||
raise IOError
|
||||
|
||||
def __dealloc__(self):
|
||||
fclose(self.fp)
|
||||
|
||||
def close(self):
|
||||
fclose(self.fp)
|
||||
|
||||
cdef int read(self, void* dest, size_t elem_size, size_t n) except -1:
|
||||
st = fread(dest, elem_size, n, self.fp)
|
||||
if st != n:
|
||||
raise IOError
|
||||
|
||||
cdef int write(self, size_t elem_size, size_t n, void* data) except -1:
|
||||
st = fwrite(data, elem_size, n, self.fp)
|
||||
if st != n:
|
||||
raise IOError
|
||||
|
||||
cdef int write_unicode(self, unicode value):
|
||||
cdef bytes py_bytes = value.encode('utf8')
|
||||
cdef char* chars = <char*>py_bytes
|
||||
self.write(sizeof(char), len(py_bytes), chars)
|
||||
|
|
Loading…
Reference in New Issue
Block a user