* Work on word vectors, and other stuff

Matthew Honnibal 2015-01-17 16:21:17 +11:00
parent 7e69e17161
commit 6c7e44140b
17 changed files with 280 additions and 154 deletions

View File

@@ -1,5 +1,6 @@
from __future__ import unicode_literals
from os import path
import re
from .. import orth
from ..vocab import Vocab
@@ -11,6 +12,9 @@ from .pos import POS_TAGS
from .attrs import get_flags
from ..util import read_lang_data
def get_lex_props(string):
return {
'flags': get_flags(string),
@@ -64,11 +68,16 @@ class English(object):
tag_names = list(POS_TAGS.keys())
tag_names.sort()
if data_dir is None:
self.tokenizer = Tokenizer(self.vocab, {}, None, None, None,
POS_TAGS, tag_names)
tok_rules = {}
prefix_re = None
suffix_re = None
infix_re = None
else:
self.tokenizer = Tokenizer.from_dir(self.vocab, path.join(data_dir, 'tokenizer'),
POS_TAGS, tag_names)
tok_data_dir = path.join(data_dir, 'tokenizer')
tok_rules, prefix_re, suffix_re, infix_re = read_lang_data(tok_data_dir)
self.tokenizer = Tokenizer(self.vocab, tok_rules, re.compile(prefix_re),
re.compile(suffix_re), re.compile(infix_re),
POS_TAGS, tag_names)
self.strings = self.vocab.strings
self._tagger = None
self._parser = None
@@ -100,11 +109,11 @@ class English(object):
Returns:
tokens (spacy.tokens.Tokens):
"""
tokens = self.tokenizer.tokenize(text)
tokens = self.tokenizer(text)
if tag:
self.tagger(tokens)
if parse:
self.parser.parse(tokens)
self.parser(tokens)
return tokens
@property
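The net effect of the changes in this file is that the pipeline and its components are invoked by calling them directly rather than via tokenize()/parse() methods. A minimal usage sketch follows; the entry-point name and keyword defaults are not shown in the hunk, so they are assumptions, and only the tokenize-then-tag-then-parse order is taken from the code above.

from spacy.en import English

nlp = English()                                   # hypothetical default: bundled data_dir
tokens = nlp(u'I like word vectors.', tag=True, parse=False)
for token in tokens:
    print(token.tag_)                             # string view of the new integer .tag field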

View File

@@ -1,18 +1,16 @@
from ..typedefs cimport FLAG0, FLAG1, FLAG2, FLAG3, FLAG4, FLAG5, FLAG6, FLAG7
from ..typedefs cimport FLAG8, FLAG9
from ..typedefs cimport ID as _ID
from ..typedefs cimport SIC as _SIC
from ..typedefs cimport SHAPE as _SHAPE
from ..typedefs cimport NORM1 as _NORM1
from ..typedefs cimport NORM2 as _NORM2
from ..typedefs cimport CLUSTER as _CLUSTER
from ..typedefs cimport PREFIX as _PREFIX
from ..typedefs cimport SUFFIX as _SUFFIX
from ..typedefs cimport LEMMA as _LEMMA
from ..typedefs cimport POS as _POS
from ..attrs cimport FLAG0, FLAG1, FLAG2, FLAG3, FLAG4, FLAG5, FLAG6, FLAG7
from ..attrs cimport FLAG8, FLAG9
from ..attrs cimport SIC as _SIC
from ..attrs cimport SHAPE as _SHAPE
from ..attrs cimport NORM1 as _NORM1
from ..attrs cimport NORM2 as _NORM2
from ..attrs cimport CLUSTER as _CLUSTER
from ..attrs cimport PREFIX as _PREFIX
from ..attrs cimport SUFFIX as _SUFFIX
from ..attrs cimport LEMMA as _LEMMA
from ..attrs cimport POS as _POS
# Work around the lack of global cpdef variables
cpdef enum:
IS_ALPHA = FLAG0
IS_ASCII = FLAG1
@@ -25,7 +23,6 @@ cpdef enum:
LIKE_URL = FLAG8
LIKE_NUM = FLAG9
ID = _ID
SIC = _SIC
SHAPE = _SHAPE
NORM1 = _NORM1

View File

@@ -4,33 +4,52 @@ import tarfile
import shutil
import requests
PARSER_URL = 'https://s3-us-west-1.amazonaws.com/media.spacynlp.com/en.tgz'
PARSER_URL = 'http://s3-us-west-1.amazonaws.com/media.spacynlp.com/en.tgz'
DEST_DIR = path.join(path.dirname(__file__), 'data', 'deps')
DEP_VECTORS_URL = 'http://u.cs.biu.ac.il/~yogo/data/syntemb/deps.words.bz2'
DEST_DIR = path.join(path.dirname(__file__), 'data')
def download_file(url):
local_filename = url.split('/')[-1]
return path.join(DEST_DIR, local_filename)
# NOTE the stream=True parameter
r = requests.get(url, stream=True)
print "Download %s" % url
i = 0
with open(local_filename, 'wb') as f:
for chunk in r.iter_content(chunk_size=1024):
if chunk: # filter out keep-alive new chunks
f.write(chunk)
f.flush()
print i
i += 1
return local_filename
def main():
if not os.path.exists(DEST_DIR):
os.mkdir(DEST_DIR)
assert not path.exists(path.join(DEST_DIR, 'en'))
def install_parser_model(url, dest_dir):
if not os.path.exists(dest_dir):
os.mkdir(dest_dir)
assert not path.exists(path.join(dest_dir, 'en'))
filename = download_file(URL)
filename = download_file(url)
t = tarfile.open(filename, mode=":gz")
t.extractall(DEST_DIR)
shutil.move(path.join(DEST_DIR, 'en', 'deps', 'model'), DEST_DIR)
shutil.move(path.join(DEST_DIR, 'en', 'deps', 'config.json'), DEST_DIR)
shutil.rmtree(path.join(DEST_DIR, 'en'))
t.extractall(dest_dir)
shutil.move(path.join(dest_dir, 'en', 'deps', 'model'), dest_dir)
shutil.move(path.join(dest_dir, 'en', 'deps', 'config.json'), dest_dir)
shutil.rmtree(path.join(dest_dir, 'en'))
def install_dep_vectors(url, dest_dir):
if not os.path.exists(dest_dir):
os.mkdir(dest_dir)
filename = download_file(url)
shutil.move(filename, path.join(dest_dir, 'vec.bz2'))
def main():
#install_parser_model(PARSER_URL, path.join(DEST_DIR, 'deps'))
install_dep_vectors(DEP_VECTORS_URL, path.join(DEST_DIR, 'vocab'))
if __name__ == '__main__':
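For reference, the dependency-vector file fetched by install_dep_vectors is the plain-text format that write_binary_vectors in vocab.pyx consumes later: one word per line, followed by its vector values, bz2-compressed. A small sketch of peeking at it; the data path comes from install_dep_vectors above, everything else is an assumption.

import bz2
from os import path

with bz2.BZ2File(path.join(DEST_DIR, 'vocab', 'vec.bz2')) as f:
    pieces = f.readline().split()
    word, values = pieces[0], [float(p) for p in pieces[1:]]   # e.g. 300 floats per word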

View File

@@ -247,11 +247,12 @@ cdef class EnPosTagger:
cdef atom_t[N_CONTEXT_FIELDS] context
cdef const weight_t* scores
for i in range(tokens.length):
if tokens.data[i].fine_pos == 0:
if tokens.data[i].pos == 0:
fill_context(context, i, tokens.data)
scores = self.model.score(context)
tokens.data[i].fine_pos = arg_max(scores, self.model.n_classes)
tokens.data[i].tag = arg_max(scores, self.model.n_classes)
self.set_morph(i, tokens.data)
tokens.pos_scheme = self.tag_map
def train(self, Tokens tokens, object golds):
cdef int i
@@ -263,13 +264,13 @@ cdef class EnPosTagger:
scores = self.model.score(context)
guess = arg_max(scores, self.model.n_classes)
self.model.update(context, guess, golds[i], guess != golds[i])
tokens.data[i].fine_pos = guess
tokens.data[i].tag = guess
self.set_morph(i, tokens.data)
correct += guess == golds[i]
return correct
cdef int set_morph(self, const int i, TokenC* tokens) except -1:
cdef const PosTag* tag = &self.tags[tokens[i].fine_pos]
cdef const PosTag* tag = &self.tags[tokens[i].tag]
tokens[i].pos = tag.pos
cached = <_CachedMorph*>self._morph_cache.get(tag.id, tokens[i].lex.sic)
if cached is NULL:
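The tagging loop only scores tokens whose tag is still unset, so tokens pre-tagged elsewhere (for example by the tokenizer's special-case rules further down in this commit) keep their value, and arg_max simply picks the best-scoring class. A pure-Python equivalent of that last step, for clarity (a sketch, not the C implementation used above):

def arg_max(scores, n_classes):
    # Index of the highest-scoring class; mirrors the C helper used above.
    best, best_score = 0, scores[0]
    for clas in range(1, n_classes):
        if scores[clas] > best_score:
            best, best_score = clas, scores[clas]
    return best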

View File

@@ -3,39 +3,70 @@ from .typedefs cimport ID, SIC, NORM1, NORM2, SHAPE, PREFIX, SUFFIX, LENGTH, CLU
from .structs cimport LexemeC
from .strings cimport StringStore
from numpy cimport ndarray
cdef LexemeC EMPTY_LEXEME
cdef int set_lex_struct_props(LexemeC* lex, dict props, StringStore strings) except -1
cdef int set_lex_struct_props(LexemeC* lex, dict props, StringStore strings,
const float* empty_vec) except -1
cdef class Lexeme:
cdef const float* vec
cdef readonly ndarray vec
cdef readonly flags_t flags
cdef readonly attr_t id
cdef readonly attr_t length
cdef readonly attr_t sic
cdef readonly unicode norm1
cdef readonly unicode norm2
cdef readonly unicode shape
cdef readonly unicode prefix
cdef readonly unicode suffix
cdef readonly attr_t norm1
cdef readonly attr_t norm2
cdef readonly attr_t shape
cdef readonly attr_t prefix
cdef readonly attr_t suffix
cdef readonly attr_t sic_id
cdef readonly attr_t norm1_id
cdef readonly attr_t norm2_id
cdef readonly attr_t shape_id
cdef readonly attr_t prefix_id
cdef readonly attr_t suffix_id
cdef readonly unicode sic_
cdef readonly unicode norm1_
cdef readonly unicode norm2_
cdef readonly unicode shape_
cdef readonly unicode prefix_
cdef readonly unicode suffix_
cdef readonly attr_t cluster
cdef readonly float prob
cdef readonly float sentiment
# Workaround for an apparent bug in the way the decorator is handled ---
# TODO: post bug report / patch to Cython.
@staticmethod
cdef inline Lexeme from_ptr(const LexemeC* ptr, StringStore strings):
cdef Lexeme py = Lexeme.__new__(Lexeme, 300)
for i in range(300):
py.vec[i] = ptr.vec[i]
py.flags = ptr.flags
py.id = ptr.id
py.length = ptr.length
cdef Lexeme Lexeme_cinit(const LexemeC* c, StringStore strings)
py.sic = ptr.sic
py.norm1 = ptr.norm1
py.norm2 = ptr.norm2
py.shape = ptr.shape
py.prefix = ptr.prefix
py.suffix = ptr.suffix
py.sic_ = strings[ptr.sic]
py.norm1_ = strings[ptr.norm1]
py.norm2_ = strings[ptr.norm2]
py.shape_ = strings[ptr.shape]
py.prefix_ = strings[ptr.prefix]
py.suffix_ = strings[ptr.suffix]
py.cluster = ptr.cluster
py.prob = ptr.prob
py.sentiment = ptr.sentiment
return py
cdef inline bint check_flag(const LexemeC* lexeme, attr_id_t flag_id) nogil:
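With vec exposed as a readonly float32 ndarray (copied from the 300-dimensional C vector in from_ptr above), word-vector arithmetic becomes possible from Python. A sketch of cosine similarity over two Lexeme objects; the helper itself is an assumption about typical use, not part of this diff.

import numpy

def similarity(lex1, lex2):
    # Both .vec attributes are numpy float32 arrays; guard against zero vectors.
    norm = numpy.linalg.norm(lex1.vec) * numpy.linalg.norm(lex2.vec)
    return float(numpy.dot(lex1.vec, lex2.vec) / norm) if norm else 0.0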

View File

@@ -7,13 +7,14 @@ from libc.string cimport memset
from .orth cimport word_shape
from .typedefs cimport attr_t
import numpy
memset(&EMPTY_LEXEME, 0, sizeof(LexemeC))
cdef int set_lex_struct_props(LexemeC* lex, dict props, StringStore string_store) except -1:
cdef int set_lex_struct_props(LexemeC* lex, dict props, StringStore string_store,
const float* empty_vec) except -1:
lex.length = props['length']
lex.sic = string_store[props['sic']]
lex.norm1 = string_store[props['norm1']]
@@ -27,39 +28,10 @@ cdef int set_lex_struct_props(LexemeC* lex, dict props, StringStore string_store
lex.sentiment = props['sentiment']
lex.flags = props['flags']
lex.vec = empty_vec
cdef class Lexeme:
"""A dummy docstring"""
def __init__(self):
pass
cdef Lexeme Lexeme_cinit(const LexemeC* c, StringStore strings):
cdef Lexeme py = Lexeme.__new__(Lexeme)
py.vec = c.vec
py.flags = c.flags
py.id = c.id
py.length = c.length
py.sic = c.sic
py.norm1 = strings[c.norm1]
py.norm2 = strings[c.norm2]
py.shape = strings[c.shape]
py.prefix = strings[c.prefix]
py.suffix = strings[c.suffix]
py.sic_id = c.sic
py.norm1_id = c.norm1
py.norm2_id = c.norm2
py.shape_id = c.shape
py.prefix_id = c.prefix
py.suffix_id = c.suffix
py.cluster = c.cluster
py.prob = c.prob
py.sentiment = c.sentiment
return py
def __cinit__(self, int vec_size):
self.vec = numpy.ndarray(shape=(vec_size,), dtype=numpy.float32)

View File

@@ -137,8 +137,42 @@ cpdef unicode word_shape(unicode string):
return ''.join(shape)
cpdef unicode norm1(unicode string, lower_pc=0.0, upper_pc=0.0, title_pc=0.0):
"""Apply level 1 normalization:
* Case is canonicalized, using frequency statistics
* Unicode mapped to ascii, via unidecode
* Regional spelling variations are normalized
"""
pass
cpdef bytes asciied(unicode string):
cdef str stripped = unidecode(string)
if not stripped:
return b'???'
return stripped.encode('ascii')
# Exceptions --- do not convert these
_uk_us_except = set([
'our',
'ours',
'four',
'fours',
'your',
'yours',
'hour',
'hours',
'course',
'rise',
])
def uk_to_usa(unicode string):
if not string.islower():
return string
if string in _uk_us_except:
return string
our = re.compile(r'ours?$')
string = our.sub('or', string)
return string
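Expected behaviour of the new helpers, as a usage sketch; the outputs follow directly from the code above, while the import path is an assumption.

from spacy import orth

assert orth.asciied(u'naïve') == b'naive'      # unidecode folds the diacritic to ASCII
assert orth.uk_to_usa(u'colour') == u'color'   # trailing 'our'/'ours' rewritten to 'or'
assert orth.uk_to_usa(u'hour') == u'hour'      # protected by the _uk_us_except list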

View File

@@ -44,12 +44,12 @@ cdef struct TokenC:
const LexemeC* lex
Morphology morph
univ_tag_t pos
int fine_pos
int tag
int idx
int lemma
int sense
int head
int dep_tag
int dep
uint32_t l_kids
uint32_t r_kids

View File

@@ -28,7 +28,7 @@ cdef inline void fill_token(atom_t* context, const TokenC* token) nogil:
else:
context[0] = token.lex.sic
context[1] = token.lemma
context[2] = token.fine_pos
context[2] = token.tag
context[3] = token.lex.cluster
# We've read in the string little-endian, so now we can take & (2**n)-1
# to get the first n bits of the cluster.
@@ -44,7 +44,7 @@ cdef inline void fill_token(atom_t* context, const TokenC* token) nogil:
# the source that are set to 1.
context[4] = token.lex.cluster & 63
context[5] = token.lex.cluster & 15
context[6] = token.dep_tag if has_head(token) else 0
context[6] = token.dep if has_head(token) else 0
cdef int fill_context(atom_t* context, State* state) except -1:
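To make the masking trick in the comments concrete: Brown cluster paths are bit-strings, and if the string was parsed little-endian (first character becomes the lowest bit), then & (2**n)-1 recovers the first n bits of the path. A pure-Python illustration; the encoding itself is only described in the comments, so this is a sketch under that assumption.

def cluster_prefix(path, n):
    as_int = int(path[::-1], 2)        # little-endian read of a path such as '1011'
    return as_int & ((1 << n) - 1)     # cluster_prefix('1011', 2) == 1, i.e. the prefix '10'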

View File

@@ -12,7 +12,7 @@ DEF NON_MONOTONIC = True
cdef int add_dep(State *s, int head, int child, int label) except -1:
cdef int dist = head - child
s.sent[child].head = dist
s.sent[child].dep_tag = label
s.sent[child].dep = label
# Keep a bit-vector tracking child dependencies. If a word has a child at
# offset i from it, set that bit (tracking left and right separately)
if child > head:
@@ -38,7 +38,7 @@ cdef int push_stack(State *s) except -1:
if at_eol(s):
while s.stack_len != 0:
if not has_head(get_s0(s)):
get_s0(s).dep_tag = 0
get_s0(s).dep = 0
pop_stack(s)
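The bit-vector bookkeeping described in the comment above can be paraphrased in Python as follows; the exact shift used in the C struct is not shown in the hunk, so the offsets here are an assumption about the described scheme.

def add_kid_bit(token, head_i, child_i):
    # One bit per child offset, with left and right children tracked separately.
    offset = child_i - head_i
    if offset > 0:
        token['r_kids'] |= 1 << offset
    else:
        token['l_kids'] |= 1 << (-offset)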

View File

@@ -123,7 +123,7 @@ cdef class TransitionSystem:
if t.move == SHIFT:
# Set the dep label, in case we need it after we reduce
if NON_MONOTONIC:
get_s0(s).dep_tag = t.label
get_s0(s).dep = t.label
push_stack(s)
elif t.move == LEFT:
add_dep(s, s.i, s.stack[0], t.label)
@@ -132,7 +132,7 @@ cdef class TransitionSystem:
add_dep(s, s.stack[0], s.i, t.label)
push_stack(s)
elif t.move == REDUCE:
add_dep(s, s.stack[-1], s.stack[0], get_s0(s).dep_tag)
add_dep(s, s.stack[-1], s.stack[0], get_s0(s).dep)
pop_stack(s)
else:
raise Exception(t.move)

View File

@@ -9,5 +9,3 @@ cdef class GreedyParser:
cdef object cfg
cdef readonly Model model
cdef TransitionSystem moves
cpdef int parse(self, Tokens tokens) except -1

View File

@@ -65,7 +65,7 @@ cdef class GreedyParser:
hasty_templ, full_templ = get_templates(self.cfg.features)
self.model = Model(self.moves.n_moves, full_templ, model_dir)
cpdef int parse(self, Tokens tokens) except -1:
def __call__(self, Tokens tokens):
cdef:
Transition guess
uint64_t state_key

View File

@@ -28,8 +28,6 @@ cdef class Tokenizer:
cdef object _infix_re
cpdef Tokens tokens_from_list(self, list strings)
cpdef Tokens tokenize(self, unicode text)
cdef int _try_cache(self, int idx, hash_t key, Tokens tokens) except -1
cdef int _tokenize(self, Tokens tokens, UniStr* span, int start, int end) except -1

View File

@@ -31,18 +31,6 @@ cdef class Tokenizer:
self.vocab = vocab
self._load_special_tokenization(rules, pos_tags, tag_names)
@classmethod
def from_dir(cls, Vocab vocab, object data_dir, object pos_tags, object tag_names):
if not path.exists(data_dir):
raise IOError("Directory %s not found -- cannot load Tokenizer." % data_dir)
if not path.isdir(data_dir):
raise IOError("Path %s is a file, not a dir -- cannot load Tokenizer." % data_dir)
assert path.exists(data_dir) and path.isdir(data_dir)
rules, prefix_re, suffix_re, infix_re = util.read_lang_data(data_dir)
return cls(vocab, rules, re.compile(prefix_re), re.compile(suffix_re),
re.compile(infix_re), pos_tags, tag_names)
cpdef Tokens tokens_from_list(self, list strings):
cdef int length = sum([len(s) for s in strings])
cdef Tokens tokens = Tokens(self.vocab, length)
@@ -57,7 +45,7 @@ cdef class Tokenizer:
idx += len(py_string) + 1
return tokens
cpdef Tokens tokenize(self, unicode string):
def __call__(self, unicode string):
"""Tokenize a string.
The tokenization rules are defined in three places:
@@ -257,7 +245,7 @@ cdef class Tokenizer:
tokens[i].lemma = self.vocab.strings[lemma]
if 'pos' in props:
# TODO: Clean up this mess...
tokens[i].fine_pos = tag_names.index(props['pos'])
tokens[i].tag = tag_names.index(props['pos'])
tokens[i].pos = tag_map[props['pos']][0]
# These are defaults, which can be over-ridden by the
# token-specific props.
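Since from_dir is removed and tokenize() became __call__, construction now happens at the call site, mirroring the English.__init__ hunk near the top of this commit. A sketch follows; the module paths and the data_dir layout are assumptions (whatever read_lang_data expects), and vocab, pos_tags, tag_names are assumed to be set up as in en/__init__.py.

import re
from spacy import util
from spacy.tokenizer import Tokenizer

rules, prefix_re, suffix_re, infix_re = util.read_lang_data(data_dir)
tokenizer = Tokenizer(vocab, rules, re.compile(prefix_re), re.compile(suffix_re),
                      re.compile(infix_re), pos_tags, tag_names)
tokens = tokenizer(u'Give it back!')   # replaces the old tokenizer.tokenize(...) spelling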

View File

@@ -198,9 +198,8 @@ cdef class Token:
self.sentiment = t.lex.sentiment
self.flags = t.lex.flags
self.lemma = t.lemma
self.pos = t.pos
self.fine_pos = t.fine_pos
self.dep_tag = t.dep_tag
self.tag = t.tag
self.dep = t.dep
def __unicode__(self):
cdef const TokenC* t = &self._seq.data[self.i]
@@ -220,6 +219,12 @@ cdef class Token:
"""
return self._seq.data[self.i].lex.length
def check_flag(self, attr_id_t flag):
return False
def is_pos(self, univ_tag_t pos):
return False
property head:
"""The token predicted by the parser to be the head of the current token."""
def __get__(self):
@@ -267,16 +272,10 @@ cdef class Token:
cdef unicode py_ustr = self._seq.vocab.strings[t.lemma]
return py_ustr
property pos_:
property tag_:
def __get__(self):
return self._seq.vocab.strings[self.pos]
return self._seq.tag_names[self.tag]
property fine_pos_:
property dep_:
def __get__(self):
return self._seq.vocab.strings[self.fine_pos]
property dep_tag_:
def __get__(self):
return self._seq.vocab.strings[self.dep_tag]
return self._seq.dep_names[self.dep]
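The renamed string views now index into the tag and dependency label name lists held by the containing Tokens object, rather than into the StringStore. A small access sketch (the iteration protocol of Tokens is assumed):

for token in tokens:
    print('%d -> %s, %d -> %s' % (token.tag, token.tag_, token.dep, token.dep_))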

View File

@@ -1,16 +1,20 @@
from libc.stdio cimport fopen, fclose, fread, fwrite, FILE
from libc.string cimport memset
from libc.stdint cimport int32_t
import bz2
from os import path
import codecs
from .lexeme cimport EMPTY_LEXEME
from .lexeme cimport set_lex_struct_props
from .lexeme cimport Lexeme_cinit
from .lexeme cimport Lexeme
from .strings cimport slice_unicode
from .strings cimport hash_string
from .orth cimport word_shape
from cymem.cymem cimport Address
DEF MAX_VEC_SIZE = 100000
@@ -34,12 +38,15 @@ cdef class Vocab:
if data_dir is not None:
if not path.exists(data_dir):
raise IOError("Directory %s not found -- cannot load Vocab." % data_dir)
assert EMPTY_LEXEME.vec != NULL
if data_dir is not None:
if not path.isdir(data_dir):
raise IOError("Path %s is a file, not a dir -- cannot load Vocab." % data_dir)
self.strings.load(path.join(data_dir, 'strings.txt'))
self.load_lexemes(path.join(data_dir, 'lexemes.bin'))
#self.load_vectors(path.join(data_dir, 'deps.words'))
self.load_vectors(path.join(data_dir, 'vec.bin'))
for i in range(self.lexemes.size()):
assert self.lexemes[i].vec != NULL, repr(self.strings[self.lexemes[i].sic])
def __len__(self):
"""The current number of lexemes stored."""
@@ -52,13 +59,15 @@ cdef class Vocab:
cdef LexemeC* lex
lex = <LexemeC*>self._map.get(c_str.key)
if lex != NULL:
assert lex.vec != NULL
return lex
if c_str.n < 3:
mem = self.mem
cdef unicode py_str = c_str.chars[:c_str.n]
lex = <LexemeC*>mem.alloc(sizeof(LexemeC), 1)
props = self.lexeme_props_getter(py_str)
set_lex_struct_props(lex, props, self.strings)
set_lex_struct_props(lex, props, self.strings, EMPTY_VEC)
assert lex.vec != NULL
if mem is self.mem:
lex.id = self.lexemes.size()
self._add_lex_to_vocab(c_str.key, lex)
@@ -98,7 +107,7 @@ cdef class Vocab:
lexeme = self.get(self.mem, &c_str)
else:
raise ValueError("Vocab unable to map type: %s. Maps unicode --> int or int --> unicode" % str(type(id_or_string)))
return Lexeme_cinit(lexeme, self.strings)
return Lexeme.from_ptr(lexeme, self.strings)
def __setitem__(self, unicode py_str, dict props):
cdef UniStr c_str
@@ -109,7 +118,8 @@ cdef class Vocab:
lex = <LexemeC*>self.mem.alloc(sizeof(LexemeC), 1)
lex.id = self.lexemes.size()
self._add_lex_to_vocab(c_str.key, lex)
set_lex_struct_props(lex, props, self.strings)
set_lex_struct_props(lex, props, self.strings, EMPTY_VEC)
assert lex.vec != NULL
assert lex.sic < 1000000
def dump(self, loc):
@@ -147,8 +157,9 @@ cdef class Vocab:
if st != 1:
break
lexeme = <LexemeC*>self.mem.alloc(sizeof(LexemeC), 1)
lexeme.vec = EMPTY_VEC
# Copies data from the file into the lexeme
st = fread(lexeme, sizeof(LexemeC), 1, fp)
lexeme.vec = EMPTY_VEC
if st != 1:
break
self._map.set(key, lexeme)
@@ -157,29 +168,98 @@ cdef class Vocab:
self.lexemes[lexeme.id] = lexeme
i += 1
fclose(fp)
def load_vectors(self, loc):
cdef int i
cdef unicode line
cdef unicode word
cdef unicode val_str
cdef hash_t key
cdef LexemeC* lex
file_ = _CFile(loc, 'rb')
cdef int32_t word_len
cdef int32_t vec_len
cdef float* vec
with codecs.open(loc, 'r', 'utf8') as file_:
for line in file_:
pieces = line.split()
word = pieces.pop(0)
if len(pieces) >= MAX_VEC_SIZE:
sizes = (len(pieces), MAX_VEC_SIZE)
msg = ("Your vector is %d elements."
"The compile-time limit is %d elements." % sizes)
raise ValueError(msg)
key = hash_string(word)
lex = <LexemeC*>self._map.get(key)
if lex is not NULL:
vec = <float*>self.mem.alloc(len(pieces), sizeof(float))
for i, val_str in enumerate(pieces):
vec[i] = float(val_str)
lex.vec = vec
cdef Address mem
cdef id_t string_id
cdef bytes py_word
cdef vector[float*] vectors
cdef int i
while True:
try:
file_.read(&word_len, sizeof(word_len), 1)
except IOError:
break
file_.read(&vec_len, sizeof(vec_len), 1)
mem = Address(word_len, sizeof(char))
chars = <char*>mem.ptr
vec = <float*>self.mem.alloc(vec_len, sizeof(float))
file_.read(chars, sizeof(char), word_len)
file_.read(vec, sizeof(float), vec_len)
string_id = self.strings[chars[:word_len]]
while string_id >= vectors.size():
vectors.push_back(EMPTY_VEC)
assert vec != NULL
vectors[string_id] = vec
cdef LexemeC* lex
for i in range(self.lexemes.size()):
# Cast away the const, cos we can modify our lexemes
lex = <LexemeC*>self.lexemes[i]
if lex.sic < vectors.size():
lex.vec = vectors[lex.sic]
else:
lex.vec = EMPTY_VEC
assert lex.vec != NULL
def write_binary_vectors(in_loc, out_loc):
cdef _CFile out_file = _CFile(out_loc, 'wb')
cdef Address mem
cdef int32_t word_len
cdef int32_t vec_len
cdef char* chars
with bz2.BZ2File(in_loc, 'r') as file_:
for line in file_:
pieces = line.split()
word = pieces.pop(0)
mem = Address(len(pieces), sizeof(float))
vec = <float*>mem.ptr
for i, val_str in enumerate(pieces):
vec[i] = float(val_str)
word_len = len(word)
vec_len = len(pieces)
out_file.write(sizeof(word_len), 1, &word_len)
out_file.write(sizeof(vec_len), 1, &vec_len)
chars = <char*>word
out_file.write(sizeof(char), len(word), chars)
out_file.write(sizeof(float), vec_len, vec)
cdef class _CFile:
cdef FILE* fp
def __init__(self, loc, mode):
cdef bytes bytes_loc = loc.encode('utf8') if type(loc) == unicode else loc
self.fp = fopen(<char*>bytes_loc, mode)
if self.fp == NULL:
raise IOError
def __dealloc__(self):
fclose(self.fp)
def close(self):
fclose(self.fp)
cdef int read(self, void* dest, size_t elem_size, size_t n) except -1:
st = fread(dest, elem_size, n, self.fp)
if st != n:
raise IOError
cdef int write(self, size_t elem_size, size_t n, void* data) except -1:
st = fwrite(data, elem_size, n, self.fp)
if st != n:
raise IOError
cdef int write_unicode(self, unicode value):
cdef bytes py_bytes = value.encode('utf8')
cdef char* chars = <char*>py_bytes
self.write(sizeof(char), len(py_bytes), chars)
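The binary layout produced by write_binary_vectors and consumed by load_vectors is: an int32 word length, an int32 vector length, the word's bytes, then that many float32 values, repeated once per entry. A stand-alone reader sketch in plain Python, assuming native byte order (as fread/fwrite use) and UTF-8 word bytes:

import struct

def read_binary_vectors(loc):
    vectors = {}
    with open(loc, 'rb') as f:
        while True:
            header = f.read(8)
            if len(header) < 8:
                break
            word_len, vec_len = struct.unpack('ii', header)
            word = f.read(word_len).decode('utf8')
            values = struct.unpack('%df' % vec_len, f.read(vec_len * 4))
            vectors[word] = values
    return vectors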