Get spaCy train command working with neural network

* Integrate models into pipeline
* Add basic serialization (maybe incorrect)
* Fix pickle on vocab
Matthew Honnibal 2017-05-17 12:04:50 +02:00
parent 3bf4a28d8d
commit 793430aa7a
16 changed files with 649 additions and 240 deletions

View File

@ -81,17 +81,19 @@ class CLI(object):
train_data=("location of JSON-formatted training data", "positional", None, str),
dev_data=("location of JSON-formatted development data (optional)", "positional", None, str),
n_iter=("number of iterations", "option", "n", int),
nsents=("number of sentences", "option", None, int),
parser_L1=("L1 regularization penalty for parser", "option", "L", float),
no_tagger=("Don't train tagger", "flag", "T", bool),
no_parser=("Don't train parser", "flag", "P", bool),
no_ner=("Don't train NER", "flag", "N", bool)
)
def train(self, lang, output_dir, train_data, dev_data=None, n_iter=15,
parser_L1=0.0, no_tagger=False, no_parser=False, no_ner=False):
nsents=0, parser_L1=0.0, no_tagger=False, no_parser=False, no_ner=False):
"""
Train a model. Expects data in spaCy's JSON format.
"""
cli_train(lang, output_dir, train_data, dev_data, n_iter, not no_tagger,
nsents = nsents or None
cli_train(lang, output_dir, train_data, dev_data, n_iter, nsents, not no_tagger,
not no_parser, not no_ner, parser_L1)
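
For orientation, a minimal sketch of calling the new entry point from Python; the import path is an assumption, but the keyword arguments mirror the signature above (nsents=0 means "no limit", since it is coerced to None before being passed on):

    # Sketch only: assumes the CLI class above lives in spacy/__main__.py and is importable.
    from spacy.__main__ import CLI

    cli = CLI()
    # Train an English model for 10 iterations on at most 5000 sentences.
    cli.train('en', '/tmp/model', 'train.json', 'dev.json',
              n_iter=10, nsents=5000, no_ner=True)
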
@plac.annotations(

View File

@ -3,9 +3,11 @@ from cymem.cymem cimport Pool
cdef class CFile:
cdef FILE* fp
cdef bint is_open
cdef unsigned char* data
cdef int is_open
cdef Pool mem
cdef int size # For compatibility with subclass
cdef int i # For compatibility with subclass
cdef int _capacity # For compatibility with subclass
cdef int read_into(self, void* dest, size_t number, size_t elem_size) except -1
@ -16,8 +18,13 @@ cdef class CFile:
cdef class StringCFile(CFile):
cdef class StringCFile:
cdef unsigned char* data
cdef int is_open
cdef Pool mem
cdef int size # For compatibility with subclass
cdef int i # For compatibility with subclass
cdef int _capacity # For compatibility with subclass
cdef int read_into(self, void* dest, size_t number, size_t elem_size) except -1

View File

@ -53,31 +53,43 @@ cdef class CFile:
cdef class StringCFile:
def __init__(self, mode, bytes data=b'', on_open_error=None):
def __init__(self, bytes data, mode, on_open_error=None):
self.mem = Pool()
self.is_open = 'w' in mode
self.is_open = 1 if 'w' in mode else 0
self._capacity = max(len(data), 8)
self.size = len(data)
self.i = 0
self.data = <unsigned char*>self.mem.alloc(1, self._capacity)
for i in range(len(data)):
self.data[i] = data[i]
def __dealloc__(self):
# Important to override this -- or
# we try to close a non-existent file pointer!
pass
def close(self):
self.is_open = False
def string_data(self):
return (self.data-self.size)[:self.size]
cdef bytes byte_string = b'\0' * (self.size)
bytes_ptr = <char*>byte_string
for i in range(self.size):
bytes_ptr[i] = self.data[i]
print(byte_string)
return byte_string
cdef int read_into(self, void* dest, size_t number, size_t elem_size) except -1:
memcpy(dest, self.data, elem_size * number)
self.data += elem_size * number
if self.i+(number * elem_size) < self.size:
memcpy(dest, &self.data[self.i], elem_size * number)
self.i += elem_size * number
cdef int write_from(self, void* src, size_t elem_size, size_t number) except -1:
write_size = number * elem_size
if (self.size + write_size) >= self._capacity:
self._capacity = (self.size + write_size) * 2
self.data = <unsigned char*>self.mem.realloc(self.data, self._capacity)
memcpy(&self.data[self.size], src, elem_size * number)
memcpy(&self.data[self.size], src, write_size)
self.size += write_size
cdef void* alloc_read(self, Pool mem, size_t number, size_t elem_size) except *:
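
The StringCFile changes above replace pointer arithmetic with an explicit read cursor (i) and a capacity that doubles on write. A pure-Python sketch of the same bookkeeping (names are illustrative, not part of spaCy):

    class StringBuffer:
        """Grow-on-write byte buffer with a read cursor, mirroring StringCFile."""
        def __init__(self, data=b''):
            self.data = bytearray(data)
            self.size = len(data)
            self.i = 0                      # read cursor

        def write_from(self, src):
            self.data.extend(src)           # bytearray handles reallocation for us
            self.size += len(src)

        def read_into(self, n):
            if self.i + n > self.size:
                raise IOError("read past end of buffer")
            chunk = bytes(self.data[self.i:self.i + n])
            self.i += n
            return chunk

        def string_data(self):
            return bytes(self.data[:self.size])

    buf = StringBuffer()
    buf.write_from(b'hello')
    assert buf.read_into(5) == b'hello'
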

View File

@ -4,16 +4,20 @@ from __future__ import unicode_literals, division, print_function
import json
from collections import defaultdict
import cytoolz
from pathlib import Path
import dill
from ..tokens.doc import Doc
from ..scorer import Scorer
from ..gold import GoldParse, merge_sents
from ..gold import read_json_file as read_gold_json
from ..util import prints
from .. import util
from .. import displacy
def train(language, output_dir, train_data, dev_data, n_iter, tagger, parser, ner,
parser_L1):
def train(language, output_dir, train_data, dev_data, n_iter, n_sents,
tagger, parser, ner, parser_L1):
output_path = util.ensure_path(output_dir)
train_path = util.ensure_path(train_data)
dev_path = util.ensure_path(dev_data)
@ -39,10 +43,8 @@ def train(language, output_dir, train_data, dev_data, n_iter, tagger, parser, ne
'n_iter': n_iter,
'lang': language,
'features': lang.Defaults.tagger_features}
gold_train = list(read_gold_json(train_path))[:100]
gold_dev = list(read_gold_json(dev_path)) if dev_path else None
gold_dev = gold_dev[:100]
gold_train = list(read_gold_json(train_path, limit=n_sents))
gold_dev = list(read_gold_json(dev_path, limit=n_sents)) if dev_path else None
train_model(lang, gold_train, gold_dev, output_path, n_iter)
if gold_dev:
@ -63,34 +65,48 @@ def train_config(config):
def train_model(Language, train_data, dev_data, output_path, n_iter, **cfg):
print("Itn.\tDep. Loss\tUAS\tNER F.\tTag %\tToken %")
nlp = Language(pipeline=['token_vectors', 'tags', 'dependencies', 'entities'])
nlp = Language(pipeline=['token_vectors', 'tags', 'dependencies'])
# TODO: Get spaCy using Thinc's trainer and optimizer
with nlp.begin_training(train_data, **cfg) as (trainer, optimizer):
for itn, epoch in enumerate(trainer.epochs(n_iter)):
for itn, epoch in enumerate(trainer.epochs(n_iter, gold_preproc=True)):
losses = defaultdict(float)
for docs, golds in epoch:
to_render = []
for i, (docs, golds) in enumerate(epoch):
state = nlp.update(docs, golds, drop=0., sgd=optimizer)
losses['dep_loss'] += state.get('parser_loss', 0.0)
to_render.insert(0, nlp(docs[-1].text))
to_render[0].user_data['title'] = "Batch %d" % i
with Path('/tmp/entities.html').open('w') as file_:
html = displacy.render(to_render[:5], style='ent', page=True,
options={'compact': True})
file_.write(html)
with Path('/tmp/parses.html').open('w') as file_:
html = displacy.render(to_render[:5], style='dep', page=True,
options={'compact': True})
file_.write(html)
if dev_data:
dev_scores = trainer.evaluate(dev_data).scores
else:
dev_scores = defaultdict(float)
print_progress(itn, losses, dev_scores)
with (output_path / 'model.bin').open('wb') as file_:
dill.dump(nlp, file_, -1)
#nlp.to_disk(output_path, tokenizer=False)
def evaluate(Language, gold_tuples, output_path):
print("Load parser", output_path)
nlp = Language(path=output_path)
def evaluate(Language, gold_tuples, path):
with (path / 'model.bin').open('rb') as file_:
nlp = dill.load(file_)
scorer = Scorer()
for raw_text, sents in gold_tuples:
sents = merge_sents(sents)
for annot_tuples, brackets in sents:
if raw_text is None:
tokens = nlp.tokenizer.tokens_from_list(annot_tuples[1])
nlp.tagger(tokens)
nlp.parser(tokens)
nlp.entity(tokens)
tokens = Doc(nlp.vocab, words=annot_tuples[1])
state = None
for proc in nlp.pipeline:
state = proc(tokens, state=state)
else:
tokens = nlp(raw_text)
gold = GoldParse.from_annot_tuples(tokens, annot_tuples)

View File

@ -138,14 +138,16 @@ def _min_edit_path(cand_words, gold_words):
return prev_costs[n_gold], previous_row[-1]
def read_json_file(loc, docs_filter=None):
def read_json_file(loc, docs_filter=None, make_supertags=False, limit=None):
loc = ensure_path(loc)
if loc.is_dir():
for filename in loc.iterdir():
yield from read_json_file(loc / filename)
yield from read_json_file(loc / filename, limit=limit)
else:
with loc.open('r', encoding='utf8') as file_:
docs = ujson.load(file_)
if limit is not None:
docs = docs[:limit]
for doc in docs:
if docs_filter is not None and not docs_filter(doc):
continue
@ -169,11 +171,13 @@ def read_json_file(loc, docs_filter=None):
if labels[-1].lower() == 'root':
labels[-1] = 'ROOT'
ner.append(token.get('ner', '-'))
sents.append((
(ids, words, tags, heads, labels, ner),
sent.get('brackets', [])))
if make_supertags:
tags[-1] = '-'.join((tags[-1], labels[-1], ner[-1]))
sents.append([
[ids, words, tags, heads, labels, ner],
sent.get('brackets', [])])
if sents:
yield (paragraph.get('raw', None), sents)
yield [paragraph.get('raw', None), sents]
def _iob_to_biluo(tags):
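
read_json_file now takes a limit, which is what the new nsents CLI option feeds through. For orientation, the JSON it consumes looks roughly like this (key names inferred from this function and spaCy's documented training format; details may differ):

    [{
        "id": 0,
        "paragraphs": [{
            "raw": "I like London.",
            "sentences": [{
                "tokens": [
                    {"id": 0, "orth": "I",      "tag": "PRP", "head": 1,  "dep": "nsubj", "ner": "O"},
                    {"id": 1, "orth": "like",   "tag": "VBP", "head": 0,  "dep": "ROOT",  "ner": "O"},
                    {"id": 2, "orth": "London", "tag": "NNP", "head": -1, "dep": "dobj",  "ner": "U-GPE"},
                    {"id": 3, "orth": ".",      "tag": ".",   "head": -2, "dep": "punct", "ner": "O"}
                ],
                "brackets": []
            }]
        }]
    }]
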

View File

@ -1,6 +1,7 @@
# coding: utf8
from __future__ import absolute_import, unicode_literals
from contextlib import contextmanager
import dill
from .tokenizer import Tokenizer
from .vocab import Vocab
@ -188,10 +189,18 @@ class Language(object):
@contextmanager
def begin_training(self, gold_tuples, **cfg):
# Populate vocab
for _, annots_brackets in gold_tuples:
for annots, _ in annots_brackets:
for word in annots[1]:
_ = self.vocab[word]
# Handle crossing dependencies
gold_tuples = PseudoProjectivity.preprocess_training_data(gold_tuples)
contexts = []
for proc in self.pipeline:
if hasattr(proc, 'begin_training'):
context = proc.begin_training(gold_tuples, pipeline=self.pipeline)
context = proc.begin_training(gold_tuples,
pipeline=self.pipeline)
contexts.append(context)
trainer = Trainer(self, gold_tuples, **cfg)
yield trainer, trainer.optimizer
@ -221,15 +230,72 @@ class Language(object):
for doc, state in stream:
yield doc
def to_disk(self, path):
raise NotImplemented
def to_disk(self, path, **exclude):
"""Save the current state to a directory.
def from_disk(self, path):
raise NotImplemented
Args:
path: A path to a directory, which will be created if it doesn't
exist. Paths may be either strings or pathlib.Path-like
objects.
**exclude: Prevent named attributes from being saved.
"""
path = util.ensure_path(path)
if not path.exists():
path.mkdir()
if not path.is_dir():
raise IOError("Output path must be a directory")
props = {}
for name, value in self.__dict__.items():
if name in exclude:
continue
if hasattr(value, 'to_disk'):
value.to_disk(path / name)
else:
props[name] = value
with (path / 'props.pickle').open('wb') as file_:
dill.dump(props, file_)
def to_bytes(self, path):
raise NotImplemented
def from_disk(self, path, **exclude):
"""Load the current state from a directory.
def from_bytes(self, path):
raise NotImplemented
Args:
path: A path to a directory. Paths may be either strings or
pathlib.Path-like objects.
**exclude: Prevent named attributes from being loaded.
"""
path = util.ensure_path(path)
for name in path.iterdir():
if name not in exclude and hasattr(self, str(name)):
getattr(self, name).from_disk(path / name)
with (path / 'props.pickle').open('rb') as file_:
bytes_data = file_.read()
self.from_bytes(bytes_data, **exclude)
return self
def to_bytes(self, **exclude):
"""Serialize the current state to a binary string.
Args:
**exclude: Prevent named attributes from being serialized.
"""
props = dict(self.__dict__)
for key in exclude:
if key in props:
props.pop(key)
return dill.dumps(props, -1)
def from_bytes(self, bytes_data, **exclude):
"""Load state from a binary string.
Args:
bytes_data (bytes): The data to load from.
**exclude: Prevent named attributes from being loaded.
"""
props = dill.loads(bytes_data)
for key, value in props.items():
if key not in exclude:
setattr(self, key, value)
return self
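
The new to_bytes/from_bytes are a thin pickle of the instance __dict__, with named attributes optionally excluded. A minimal stand-alone sketch of that pattern (using stdlib pickle here; the commit uses dill so that model objects survive):

    import pickle

    class Serializable(object):
        def to_bytes(self, **exclude):
            props = {k: v for k, v in self.__dict__.items() if k not in exclude}
            return pickle.dumps(props, -1)

        def from_bytes(self, bytes_data, **exclude):
            for key, value in pickle.loads(bytes_data).items():
                if key not in exclude:
                    setattr(self, key, value)
            return self

    obj = Serializable()
    obj.meta = {'lang': 'en'}
    obj.cache = lambda x: x                      # not picklable
    data = obj.to_bytes(cache=True)              # so leave it out
    restored = Serializable().from_bytes(data)
    assert restored.meta == {'lang': 'en'}
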

View File

@ -2,7 +2,7 @@ from .typedefs cimport attr_t, hash_t, flags_t, len_t, tag_t
from .attrs cimport attr_id_t
from .attrs cimport ID, ORTH, LOWER, NORM, SHAPE, PREFIX, SUFFIX, LENGTH, CLUSTER, LANG
from .structs cimport LexemeC
from .structs cimport LexemeC, SerializedLexemeC
from .strings cimport StringStore
from .vocab cimport Vocab
@ -22,7 +22,23 @@ cdef class Lexeme:
self.c = lex
self.vocab = vocab
self.orth = lex.orth
@staticmethod
cdef inline SerializedLexemeC c_to_bytes(const LexemeC* lex) nogil:
cdef SerializedLexemeC lex_data
buff = <const unsigned char*>&lex.flags
end = <const unsigned char*>&lex.l2_norm + sizeof(lex.l2_norm)
for i in range(sizeof(lex_data.data)):
lex_data.data[i] = buff[i]
return lex_data
@staticmethod
cdef inline void c_from_bytes(LexemeC* lex, SerializedLexemeC lex_data) nogil:
buff = <unsigned char*>&lex.flags
end = <unsigned char*>&lex.l2_norm + sizeof(lex.l2_norm)
for i in range(sizeof(lex_data.data)):
buff[i] = lex_data.data[i]
@staticmethod
cdef inline void set_struct_attr(LexemeC* lex, attr_id_t name, attr_t value) nogil:
if name < (sizeof(flags_t) * 8):

View File

@ -116,6 +116,29 @@ cdef class Lexeme:
return 0.0
return numpy.dot(self.vector, other.vector) / (self.vector_norm * other.vector_norm)
def to_bytes(self):
lex_data = Lexeme.c_to_bytes(self.c)
start = <const char*>&self.c.flags
end = <const char*>&self.c.l2_norm + sizeof(self.c.l2_norm)
assert (end-start) == sizeof(lex_data.data), (end-start, sizeof(lex_data.data))
byte_string = b'\0' * sizeof(lex_data.data)
byte_chars = <char*>byte_string
for i in range(sizeof(lex_data.data)):
byte_chars[i] = lex_data.data[i]
assert len(byte_string) == sizeof(lex_data.data), (len(byte_string),
sizeof(lex_data.data))
return byte_string
def from_bytes(self, bytes byte_string):
# This method doesn't really have a use-case --- wrote it for testing.
# Possibly delete? It puts the Lexeme out of synch with the vocab.
cdef SerializedLexemeC lex_data
assert len(byte_string) == sizeof(lex_data.data)
for i in range(len(byte_string)):
lex_data.data[i] = byte_string[i]
Lexeme.c_from_bytes(self.c, lex_data)
self.orth = self.c.orth
property has_vector:
def __get__(self):
cdef int i
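
c_to_bytes/c_from_bytes simply memcpy the fixed-width numeric fields of the LexemeC struct into a byte buffer and back, and to_bytes/from_bytes above expose that buffer to Python. A rough stdlib analogue for a few fields (the field widths here are illustrative):

    import struct

    FMT = '<QIIf'   # e.g. flags (uint64), orth (uint32), lower (uint32), prob (float32)
    blob = struct.pack(FMT, 0b1010, 1042, 7, -3.5)
    assert len(blob) == struct.calcsize(FMT) == 20
    flags, orth, lower, prob = struct.unpack(FMT, blob)
    assert (flags, orth, lower, prob) == (0b1010, 1042, 7, -3.5)
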

View File

@ -26,10 +26,14 @@ from .syntax.beam_parser cimport BeamParser
from .syntax.ner cimport BiluoPushDown
from .syntax.arc_eager cimport ArcEager
from .tagger import Tagger
from .syntax.stateclass cimport StateClass
from .gold cimport GoldParse
from .morphology cimport Morphology
from .vocab cimport Vocab
from .attrs import ID, LOWER, PREFIX, SUFFIX, SHAPE, TAG, DEP
from .attrs import ID, LOWER, PREFIX, SUFFIX, SHAPE, TAG, DEP, POS
from ._ml import Tok2Vec, flatten, get_col, doc2feats
from .parts_of_speech import X
class TokenVectorEncoder(object):
@ -50,7 +54,7 @@ class TokenVectorEncoder(object):
docs = [docs]
tokvecs = self.predict(docs)
self.set_annotations(docs, tokvecs)
state = {} if state is not None else state
state = {} if state is None else state
state['tokvecs'] = tokvecs
return state
@ -58,7 +62,6 @@ class TokenVectorEncoder(object):
raise NotImplementedError
def predict(self, docs):
cdef Doc doc
feats = self.doc2feats(docs)
tokvecs = self.model(feats)
return tokvecs
@ -68,7 +71,7 @@ class TokenVectorEncoder(object):
for doc in docs:
doc.tensor = tokvecs[start : start + len(doc)]
start += len(doc)
def update(self, docs, golds, state=None,
drop=0., sgd=None):
if isinstance(docs, Doc):
@ -88,9 +91,9 @@ class TokenVectorEncoder(object):
class NeuralTagger(object):
name = 'nn_tagger'
def __init__(self, vocab):
def __init__(self, vocab, model=True):
self.vocab = vocab
self.model = Softmax(self.vocab.morphology.n_tags)
self.model = model
def __call__(self, doc, state=None):
assert state is not None
@ -132,7 +135,7 @@ class NeuralTagger(object):
bp_tokvecs = state['bp_tokvecs']
if self.model.nI is None:
self.model.nI = tokvecs.shape[1]
tag_scores, bp_tag_scores = self.model.begin_update(tokvecs, drop=drop)
loss, d_tag_scores = self.get_loss(docs, golds, tag_scores)
d_tokvecs = bp_tag_scores(d_tag_scores, sgd)
@ -141,7 +144,7 @@ class NeuralTagger(object):
state['bp_tag_scores'] = bp_tag_scores
state['d_tag_scores'] = d_tag_scores
state['tag_loss'] = loss
if 'd_tokvecs' in state:
state['d_tokvecs'] += d_tokvecs
else:
@ -161,6 +164,22 @@ class NeuralTagger(object):
d_scores = scores - to_categorical(correct, nb_classes=scores.shape[1])
return (d_scores**2).sum(), d_scores
def begin_training(self, gold_tuples, pipeline=None):
# Populate tag map, if anything's missing.
tag_map = dict(self.vocab.morphology.tag_map)
for raw_text, annots_brackets in gold_tuples:
for annots, brackets in annots_brackets:
ids, words, tags, heads, deps, ents = annots
for tag in tags:
if tag not in tag_map:
tag_map[tag] = {POS: X}
cdef Vocab vocab = self.vocab
vocab.morphology = Morphology(self.vocab.strings, tag_map,
self.vocab.morphology.lemmatizer)
self.model = Softmax(self.vocab.morphology.n_tags)
cdef class EntityRecognizer(LinearParser):
"""
@ -209,6 +228,28 @@ cdef class NeuralEntityRecognizer(NeuralParser):
name = 'entity'
TransitionSystem = BiluoPushDown
nr_feature = 6
def get_token_ids(self, states):
cdef StateClass state
cdef int n_tokens = 6
ids = numpy.zeros((len(states), n_tokens), dtype='i', order='c')
for i, state in enumerate(states):
ids[i, 0] = state.c.B(0)-1
ids[i, 1] = state.c.B(0)
ids[i, 2] = state.c.B(1)
ids[i, 3] = state.c.E(0)
ids[i, 4] = state.c.E(0)-1
ids[i, 5] = state.c.E(0)+1
for j in range(6):
if ids[i, j] >= state.c.length:
ids[i, j] = -1
if ids[i, j] != -1:
ids[i, j] += state.c.offset
return ids
cdef class BeamDependencyParser(BeamParser):
TransitionSystem = ArcEager

View File

@ -28,6 +28,24 @@ cdef struct LexemeC:
float l2_norm
cdef struct SerializedLexemeC:
unsigned char[4*13 + 8] data
# sizeof(flags_t) # flags
# + sizeof(attr_t) # lang
# + sizeof(attr_t) # id
# + sizeof(attr_t) # length
# + sizeof(attr_t) # orth
# + sizeof(attr_t) # lower
# + sizeof(attr_t) # norm
# + sizeof(attr_t) # shape
# + sizeof(attr_t) # prefix
# + sizeof(attr_t) # suffix
# + sizeof(attr_t) # cluster
# + sizeof(float) # prob
# + sizeof(float) # sentiment
# + sizeof(float) # l2_norm
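
The buffer size works out as commented: one 8-byte flags_t plus 13 four-byte fields (10 attr_t attributes and 3 floats) gives 4*13 + 8 = 60 bytes. Assuming attr_t is a 32-bit integer, the same count via the struct module:

    import struct
    # flags_t (uint64) + 10 x attr_t (uint32) + prob, sentiment, l2_norm (float32)
    assert struct.calcsize('<Q10I3f') == 4 * 13 + 8 == 60
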
cdef struct Entity:
hash_t id
int start

View File

@ -12,6 +12,7 @@ from libc.math cimport exp
cimport cython
cimport cython.parallel
import cytoolz
import dill
import numpy.random
cimport numpy as np
@ -35,6 +36,7 @@ from thinc.api import layerize, chain
from thinc.neural import Model, Affine, ELU, ReLu, Maxout
from thinc.neural.ops import NumpyOps
from .. import util
from ..util import get_async, get_cuda_stream
from .._ml import zero_init, PrecomputableAffine, PrecomputableMaxouts
from .._ml import Tok2Vec, doc2feats
@ -218,9 +220,8 @@ cdef class Parser:
"""
@classmethod
def Model(cls, nr_class, token_vector_width=128, hidden_width=128, **cfg):
nr_context_tokens = StateClass.nr_context_tokens()
lower = PrecomputableMaxouts(hidden_width,
nF=nr_context_tokens,
nF=cls.nr_feature,
nI=token_vector_width,
pieces=cfg.get('maxout_pieces', 1))
@ -267,7 +268,7 @@ cdef class Parser:
self.model = model
def __reduce__(self):
return (Parser, (self.vocab, self.moves, self.model, self.cfg), None, None)
return (Parser, (self.vocab, self.moves, self.model), None, None)
def __call__(self, Doc tokens, state=None):
"""
@ -392,9 +393,11 @@ cdef class Parser:
lower, stream, drop=dropout)
return state2vec, upper
nr_feature = 13
def get_token_ids(self, states):
cdef StateClass state
cdef int n_tokens = states[0].nr_context_tokens()
cdef int n_tokens = self.nr_feature
ids = numpy.zeros((len(states), n_tokens), dtype='i', order='c')
for i, state in enumerate(states):
state.set_context_tokens(ids[i])
@ -458,6 +461,22 @@ cdef class Parser:
if self.model is True:
self.model = self.Model(self.moves.n_moves, **cfg)
def to_disk(self, path):
path = util.ensure_path(path)
with (path / 'model.bin').open('wb') as file_:
dill.dump(self.model, file_)
def from_disk(self, path):
path = util.ensure_path(path)
with (path / 'model.bin').open('rb') as file_:
self.model = dill.load(file_)
def to_bytes(self):
pass
def from_bytes(self, data):
pass
class ParserStateError(ValueError):
def __init__(self, doc):

View File

@ -19,7 +19,6 @@ def test_pickle_string_store(stringstore, text1, text2):
assert len(stringstore) == len(unpickled)
@pytest.mark.xfail
@pytest.mark.parametrize('text1,text2', [('dog', 'cat')])
def test_pickle_vocab(text1, text2):
vocab = Vocab(lex_attr_getters={int(NORM): lambda string: string[:-1]})

View File

@ -56,3 +56,18 @@ def test_vocab_lexeme_add_flag_provided_id(en_vocab):
assert en_vocab['199'].check_flag(IS_DIGIT) == False
assert en_vocab['the'].check_flag(is_len4) == False
assert en_vocab['dogs'].check_flag(is_len4) == True
def test_lexeme_bytes_roundtrip(en_vocab):
one = en_vocab['one']
alpha = en_vocab['alpha']
assert one.orth != alpha.orth
assert one.lower != alpha.lower
print(one.orth, alpha.orth)
alpha.from_bytes(one.to_bytes())
assert one.orth_ == alpha.orth_
assert one.orth == alpha.orth
assert one.lower == alpha.lower
assert one.lower_ == alpha.lower_

View File

@ -97,11 +97,10 @@ cdef class Tokenizer:
def __reduce__(self):
args = (self.vocab,
self._rules,
self._prefix_re,
self._suffix_re,
self._infix_re,
self.prefix_search,
self.suffix_search,
self.infix_finditer,
self.token_match)
return (self.__class__, args, None, None)
cpdef Doc tokens_from_list(self, list strings):

View File

@ -20,14 +20,16 @@ class Trainer(object):
"""
def __init__(self, nlp, gold_tuples):
self.nlp = nlp
self.gold_tuples = PseudoProjectivity.preprocess_training_data(gold_tuples)
self.nr_epoch = 0
self.optimizer = Adam(NumpyOps(), 0.001)
self.gold_tuples = gold_tuples
def epochs(self, nr_epoch, augment_data=None, gold_preproc=False):
cached_golds = {}
def _epoch(indices):
for i in tqdm.tqdm(indices):
all_docs = []
all_golds = []
for i in indices:
raw_text, paragraph_tuples = self.gold_tuples[i]
if gold_preproc:
raw_text = None
@ -43,7 +45,11 @@ class Trainer(object):
raw_text, paragraph_tuples = augment_data(raw_text, paragraph_tuples)
docs = self.make_docs(raw_text, paragraph_tuples)
golds = self.make_golds(docs, paragraph_tuples)
yield docs, golds
all_docs.extend(docs)
all_golds.extend(golds)
for batch in tqdm.tqdm(partition_all(12, zip(all_docs, all_golds))):
X, y = zip(*batch)
yield X, y
indices = list(range(len(self.gold_tuples)))
for itn in range(nr_epoch):
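
Epochs now collect all docs and golds up front and then yield fixed-size minibatches via cytoolz.partition_all. In isolation, the batching step looks like this (toy data):

    from cytoolz import partition_all

    docs = ['doc%d' % i for i in range(30)]
    golds = ['gold%d' % i for i in range(30)]
    for batch in partition_all(12, zip(docs, golds)):
        X, y = zip(*batch)          # batches of up to 12 (doc, gold) pairs
        assert len(X) == len(y) <= 12
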

View File

@ -5,7 +5,7 @@ import bz2
import ujson
import re
from libc.string cimport memset
from libc.string cimport memset, memcpy
from libc.stdint cimport int32_t
from libc.math cimport sqrt
from cymem.cymem cimport Address
@ -13,9 +13,10 @@ from .lexeme cimport EMPTY_LEXEME
from .lexeme cimport Lexeme
from .strings cimport hash_string
from .typedefs cimport attr_t
from .cfile cimport CFile, StringCFile
from .cfile cimport CFile
from .tokens.token cimport Token
from .attrs cimport PROB, LANG
from .structs cimport SerializedLexemeC
from .compat import copy_reg, pickle
from .lemmatizer import Lemmatizer
@ -42,6 +43,7 @@ cdef class Vocab:
def load(cls, path, lex_attr_getters=None, lemmatizer=True,
tag_map=True, oov_prob=True, **deprecated_kwargs):
"""
Deprecated --- replace in spaCy 2
Load the vocabulary from a path.
Arguments:
@ -88,6 +90,7 @@ cdef class Vocab:
self.load_lexemes(path / 'vocab' / 'lexemes.bin')
return self
def __init__(self, lex_attr_getters=None, tag_map=None, lemmatizer=None,
strings=tuple(), **deprecated_kwargs):
"""
@ -149,24 +152,7 @@ cdef class Vocab:
The current number of lexemes stored.
"""
return self.length
def resize_vectors(self, int new_size):
"""
Set vectors_length to a new size, and allocate more memory for the Lexeme
vectors if necessary. The memory will be zeroed.
Arguments:
new_size (int): The new size of the vectors.
"""
cdef hash_t key
cdef size_t addr
if new_size > self.vectors_length:
for key, addr in self._by_hash.items():
lex = <LexemeC*>addr
lex.vector = <float*>self.mem.realloc(lex.vector,
new_size * sizeof(lex.vector[0]))
self.vectors_length = new_size
def add_flag(self, flag_getter, int flag_id=-1):
"""
Set a new boolean flag to words in the vocabulary.
@ -224,7 +210,7 @@ cdef class Vocab:
if lex != NULL:
if lex.orth != self.strings[string]:
raise LookupError.mismatched_strings(
lex.orth, self.strings[string], self.strings[lex.orth], string)
lex.orth, self.strings[string], string)
return lex
else:
return self._new_lexeme(mem, string)
@ -337,148 +323,75 @@ cdef class Vocab:
Token.set_struct_attr(token, attr_id, value)
return tokens
def dump(self, loc=None):
"""
Save the lexemes binary data to the given location, or
return a byte-string with the data if loc is None.
def to_disk(self, path):
path = util.ensure_path(path)
if not path.exists():
path.mkdir()
strings_loc = path / 'strings.json'
with strings_loc.open('w', encoding='utf8') as file_:
self.strings.dump(file_)
self.dump(path / 'lexemes.bin')
Arguments:
loc (Path or None): The path to save to, or None.
"""
cdef CFile fp
if loc is None:
fp = StringCFile('wb')
else:
fp = CFile(loc, 'wb')
cdef size_t st
def from_disk(self, path):
path = util.ensure_path(path)
with (path / 'vocab' / 'strings.json').open('r', encoding='utf8') as file_:
strings_list = ujson.load(file_)
for string in strings_list:
self.strings[string]
self.load_lexemes(path / 'lexemes.bin')
def lexemes_to_bytes(self, **exclude):
cdef hash_t key
cdef size_t addr
cdef hash_t key
cdef LexemeC* lexeme = NULL
cdef SerializedLexemeC lex_data
cdef int size = 0
for key, addr in self._by_hash.items():
if addr == 0:
continue
size += sizeof(lex_data.data)
byte_string = b'\0' * size
byte_ptr = <unsigned char*>byte_string
cdef int j
cdef int i = 0
for key, addr in self._by_hash.items():
if addr == 0:
continue
lexeme = <LexemeC*>addr
fp.write_from(&lexeme.orth, sizeof(lexeme.orth), 1)
fp.write_from(&lexeme.flags, sizeof(lexeme.flags), 1)
fp.write_from(&lexeme.id, sizeof(lexeme.id), 1)
fp.write_from(&lexeme.length, sizeof(lexeme.length), 1)
fp.write_from(&lexeme.orth, sizeof(lexeme.orth), 1)
fp.write_from(&lexeme.lower, sizeof(lexeme.lower), 1)
fp.write_from(&lexeme.norm, sizeof(lexeme.norm), 1)
fp.write_from(&lexeme.shape, sizeof(lexeme.shape), 1)
fp.write_from(&lexeme.prefix, sizeof(lexeme.prefix), 1)
fp.write_from(&lexeme.suffix, sizeof(lexeme.suffix), 1)
fp.write_from(&lexeme.cluster, sizeof(lexeme.cluster), 1)
fp.write_from(&lexeme.prob, sizeof(lexeme.prob), 1)
fp.write_from(&lexeme.sentiment, sizeof(lexeme.sentiment), 1)
fp.write_from(&lexeme.l2_norm, sizeof(lexeme.l2_norm), 1)
fp.write_from(&lexeme.lang, sizeof(lexeme.lang), 1)
fp.close()
if loc is None:
return fp.string_data()
lex_data = Lexeme.c_to_bytes(lexeme)
for j in range(sizeof(lex_data.data)):
byte_ptr[i] = lex_data.data[j]
i += 1
return byte_string
def load_lexemes(self, loc):
def lexemes_from_bytes(self, bytes bytes_data):
"""
Load the binary vocabulary data from the given location.
Arguments:
loc (Path): The path to load from.
Returns:
None
Load the binary vocabulary data from the given string.
"""
fp = CFile(loc, 'rb',
on_open_error=lambda: IOError('LexemeCs file not found at %s' % loc))
cdef LexemeC* lexeme = NULL
cdef LexemeC* lexeme
cdef hash_t key
cdef unicode py_str
cdef attr_t orth = 0
assert sizeof(orth) == sizeof(lexeme.orth)
i = 0
while True:
try:
fp.read_into(&orth, 1, sizeof(orth))
except IOError:
break
lexeme = <LexemeC*>self.mem.alloc(sizeof(LexemeC), 1)
# Copy data from the file into the lexeme
fp.read_into(&lexeme.flags, 1, sizeof(lexeme.flags))
fp.read_into(&lexeme.id, 1, sizeof(lexeme.id))
fp.read_into(&lexeme.length, 1, sizeof(lexeme.length))
fp.read_into(&lexeme.orth, 1, sizeof(lexeme.orth))
fp.read_into(&lexeme.lower, 1, sizeof(lexeme.lower))
fp.read_into(&lexeme.norm, 1, sizeof(lexeme.norm))
fp.read_into(&lexeme.shape, 1, sizeof(lexeme.shape))
fp.read_into(&lexeme.prefix, 1, sizeof(lexeme.prefix))
fp.read_into(&lexeme.suffix, 1, sizeof(lexeme.suffix))
fp.read_into(&lexeme.cluster, 1, sizeof(lexeme.cluster))
fp.read_into(&lexeme.prob, 1, sizeof(lexeme.prob))
fp.read_into(&lexeme.sentiment, 1, sizeof(lexeme.sentiment))
fp.read_into(&lexeme.l2_norm, 1, sizeof(lexeme.l2_norm))
fp.read_into(&lexeme.lang, 1, sizeof(lexeme.lang))
cdef int i = 0
cdef int j = 0
cdef SerializedLexemeC lex_data
chunk_size = sizeof(lex_data.data)
cdef unsigned char* bytes_ptr = bytes_data
for i in range(0, len(bytes_data), chunk_size):
lexeme = <LexemeC*>self.mem.alloc(1, sizeof(LexemeC))
for j in range(sizeof(lex_data.data)):
lex_data.data[j] = bytes_ptr[i+j]
Lexeme.c_from_bytes(lexeme, lex_data)
lexeme.vector = EMPTY_VEC
py_str = self.strings[lexeme.orth]
assert self.strings[py_str] == lexeme.orth, (py_str, lexeme.orth)
key = hash_string(py_str)
self._by_hash.set(key, lexeme)
self._by_orth.set(lexeme.orth, lexeme)
self.length += 1
i += 1
fp.close()
def _deserialize_lexemes(self, CFile fp):
"""
Load the binary vocabulary data from the given CFile.
"""
cdef LexemeC* lexeme = NULL
cdef hash_t key
cdef unicode py_str
cdef attr_t orth = 0
assert sizeof(orth) == sizeof(lexeme.orth)
i = 0
cdef int todo = fp.size
cdef int lex_size = sizeof(lexeme.flags)
lex_size += sizeof(lexeme.id)
lex_size += sizeof(lexeme.length)
lex_size += sizeof(lexeme.orth)
lex_size += sizeof(lexeme.lower)
lex_size += sizeof(lexeme.norm)
lex_size += sizeof(lexeme.shape)
lex_size += sizeof(lexeme.prefix)
lex_size += sizeof(lexeme.suffix)
lex_size += sizeof(lexeme.cluster)
lex_size += sizeof(lexeme.prob)
lex_size += sizeof(lexeme.sentiment)
lex_size += sizeof(lexeme.l2_norm)
lex_size += sizeof(lexeme.lang)
while True:
if todo < lex_size:
break
todo -= lex_size
lexeme = <LexemeC*>self.mem.alloc(sizeof(LexemeC), 1)
# Copy data from the file into the lexeme
fp.read_into(&lexeme.flags, 1, sizeof(lexeme.flags))
fp.read_into(&lexeme.id, 1, sizeof(lexeme.id))
fp.read_into(&lexeme.length, 1, sizeof(lexeme.length))
fp.read_into(&lexeme.orth, 1, sizeof(lexeme.orth))
fp.read_into(&lexeme.lower, 1, sizeof(lexeme.lower))
fp.read_into(&lexeme.norm, 1, sizeof(lexeme.norm))
fp.read_into(&lexeme.shape, 1, sizeof(lexeme.shape))
fp.read_into(&lexeme.prefix, 1, sizeof(lexeme.prefix))
fp.read_into(&lexeme.suffix, 1, sizeof(lexeme.suffix))
fp.read_into(&lexeme.cluster, 1, sizeof(lexeme.cluster))
fp.read_into(&lexeme.prob, 1, sizeof(lexeme.prob))
fp.read_into(&lexeme.sentiment, 1, sizeof(lexeme.sentiment))
fp.read_into(&lexeme.l2_norm, 1, sizeof(lexeme.l2_norm))
fp.read_into(&lexeme.lang, 1, sizeof(lexeme.lang))
lexeme.vector = EMPTY_VEC
py_str = self.strings[lexeme.orth]
key = hash_string(py_str)
self._by_hash.set(key, lexeme)
self._by_orth.set(lexeme.orth, lexeme)
self.length += 1
i += 1
fp.close()
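
lexemes_to_bytes/lexemes_from_bytes replace the CFile-based dump/load: every lexeme is copied out as one fixed-size SerializedLexemeC record, and read back by stepping through the blob in chunks of that size. The chunking loop, sketched in pure Python (the 60-byte record size is taken from SerializedLexemeC above):

    CHUNK = 60                      # sizeof(SerializedLexemeC.data)

    def iter_lexeme_records(bytes_data):
        if len(bytes_data) % CHUNK:
            raise ValueError("truncated lexeme data")
        for i in range(0, len(bytes_data), CHUNK):
            yield bytes_data[i:i + CHUNK]

    blob = b'\0' * (3 * CHUNK)      # three empty records
    assert len(list(iter_lexeme_records(blob))) == 3
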
# Deprecated --- delete these once stable
def dump_vectors(self, out_loc):
"""
Save the word vectors to a binary file.
@ -487,7 +400,7 @@ cdef class Vocab:
loc (Path): The path to save to.
Returns:
None
"""
#"""
cdef int32_t vec_len = self.vectors_length
cdef int32_t word_len
cdef bytes word_str
@ -508,6 +421,8 @@ cdef class Vocab:
out_file.write_from(vec, vec_len, sizeof(float))
out_file.close()
def load_vectors(self, file_):
"""
Load vectors from a text-based file.
@ -610,38 +525,22 @@ cdef class Vocab:
return vec_len
def pickle_vocab(vocab):
sstore = vocab.strings
morph = vocab.morphology
length = vocab.length
data_dir = vocab.data_dir
lex_attr_getters = vocab.lex_attr_getters
def resize_vectors(self, int new_size):
"""
Set vectors_length to a new size, and allocate more memory for the Lexeme
vectors if necessary. The memory will be zeroed.
lexemes_data = vocab.dump()
vectors_length = vocab.vectors_length
return (unpickle_vocab,
(sstore, morph, data_dir, lex_attr_getters,
lexemes_data, length, vectors_length))
def unpickle_vocab(sstore, morphology, data_dir,
lex_attr_getters, bytes lexemes_data, int length, int vectors_length):
cdef Vocab vocab = Vocab()
vocab.length = length
vocab.vectors_length = vectors_length
vocab.strings = sstore
cdef CFile fp = StringCFile('r', data=lexemes_data)
vocab.morphology = morphology
vocab.data_dir = data_dir
vocab.lex_attr_getters = lex_attr_getters
vocab._deserialize_lexemes(fp)
vocab.length = length
vocab.vectors_length = vectors_length
return vocab
copy_reg.pickle(Vocab, pickle_vocab, unpickle_vocab)
Arguments:
new_size (int): The new size of the vectors.
"""
cdef hash_t key
cdef size_t addr
if new_size > self.vectors_length:
for key, addr in self._by_hash.items():
lex = <LexemeC*>addr
lex.vector = <float*>self.mem.realloc(lex.vector,
new_size * sizeof(lex.vector[0]))
self.vectors_length = new_size
def write_binary_vectors(in_loc, out_loc):
@ -670,6 +569,39 @@ def write_binary_vectors(in_loc, out_loc):
out_file.write_from(vec, vec_len, sizeof(float))
def pickle_vocab(vocab):
sstore = vocab.strings
morph = vocab.morphology
length = vocab.length
data_dir = vocab.data_dir
lex_attr_getters = vocab.lex_attr_getters
lexemes_data = vocab.lexemes_to_bytes()
vectors_length = vocab.vectors_length
return (unpickle_vocab,
(sstore, morph, data_dir, lex_attr_getters,
lexemes_data, length, vectors_length))
def unpickle_vocab(sstore, morphology, data_dir,
lex_attr_getters, bytes lexemes_data, int length, int vectors_length):
cdef Vocab vocab = Vocab()
vocab.length = length
vocab.vectors_length = vectors_length
vocab.strings = sstore
vocab.morphology = morphology
vocab.data_dir = data_dir
vocab.lex_attr_getters = lex_attr_getters
vocab.lexemes_from_bytes(lexemes_data)
vocab.length = length
vocab.vectors_length = vectors_length
return vocab
copy_reg.pickle(Vocab, pickle_vocab, unpickle_vocab)
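
pickle_vocab/unpickle_vocab now go through lexemes_to_bytes/lexemes_from_bytes instead of the StringCFile path, and are registered with copy_reg (the Python 2/3 compat alias for copyreg). The registration pattern in miniature:

    import copyreg, pickle

    class Point(object):
        def __init__(self, x, y):
            self.x, self.y = x, y

    def unpickle_point(x, y):
        return Point(x, y)

    def pickle_point(point):
        # Return the reconstructor and its arguments, like pickle_vocab above.
        return (unpickle_point, (point.x, point.y))

    copyreg.pickle(Point, pickle_point)
    p = pickle.loads(pickle.dumps(Point(1, 2)))
    assert (p.x, p.y) == (1, 2)
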
class LookupError(Exception):
@classmethod
def mismatched_strings(cls, id_, id_string, original_string):
@ -701,3 +633,237 @@ class VectorReadError(Exception):
"Vector size: %d\n"
"Max size: %d\n"
"Min size: 1\n" % (loc, size, MAX_VEC_SIZE))
#
#Deprecated --- delete these once stable
#
# def dump_vectors(self, out_loc):
# """
# Save the word vectors to a binary file.
#
# Arguments:
# loc (Path): The path to save to.
# Returns:
# None
# #"""
# cdef int32_t vec_len = self.vectors_length
# cdef int32_t word_len
# cdef bytes word_str
# cdef char* chars
#
# cdef Lexeme lexeme
# cdef CFile out_file = CFile(out_loc, 'wb')
# for lexeme in self:
# word_str = lexeme.orth_.encode('utf8')
# vec = lexeme.c.vector
# word_len = len(word_str)
#
# out_file.write_from(&word_len, 1, sizeof(word_len))
# out_file.write_from(&vec_len, 1, sizeof(vec_len))
#
# chars = <char*>word_str
# out_file.write_from(chars, word_len, sizeof(char))
# out_file.write_from(vec, vec_len, sizeof(float))
# out_file.close()
#
#
#
# def load_vectors(self, file_):
# """
# Load vectors from a text-based file.
#
# Arguments:
# file_ (buffer): The file to read from. Entries should be separated by newlines,
# and each entry should be whitespace delimited. The first value of the entry
# should be the word string, and subsequent entries should be the values of the
# vector.
#
# Returns:
# vec_len (int): The length of the vectors loaded.
# """
# cdef LexemeC* lexeme
# cdef attr_t orth
# cdef int32_t vec_len = -1
# cdef double norm = 0.0
#
# whitespace_pattern = re.compile(r'\s', re.UNICODE)
#
# for line_num, line in enumerate(file_):
# pieces = line.split()
# word_str = " " if whitespace_pattern.match(line) else pieces.pop(0)
# if vec_len == -1:
# vec_len = len(pieces)
# elif vec_len != len(pieces):
# raise VectorReadError.mismatched_sizes(file_, line_num,
# vec_len, len(pieces))
# orth = self.strings[word_str]
# lexeme = <LexemeC*><void*>self.get_by_orth(self.mem, orth)
# lexeme.vector = <float*>self.mem.alloc(vec_len, sizeof(float))
# for i, val_str in enumerate(pieces):
# lexeme.vector[i] = float(val_str)
# norm = 0.0
# for i in range(vec_len):
# norm += lexeme.vector[i] * lexeme.vector[i]
# lexeme.l2_norm = sqrt(norm)
# self.vectors_length = vec_len
# return vec_len
#
# def load_vectors_from_bin_loc(self, loc):
# """
# Load vectors from the location of a binary file.
#
# Arguments:
# loc (unicode): The path of the binary file to load from.
#
# Returns:
# vec_len (int): The length of the vectors loaded.
# """
# cdef CFile file_ = CFile(loc, b'rb')
# cdef int32_t word_len
# cdef int32_t vec_len = 0
# cdef int32_t prev_vec_len = 0
# cdef float* vec
# cdef Address mem
# cdef attr_t string_id
# cdef bytes py_word
# cdef vector[float*] vectors
# cdef int line_num = 0
# cdef Pool tmp_mem = Pool()
# while True:
# try:
# file_.read_into(&word_len, sizeof(word_len), 1)
# except IOError:
# break
# file_.read_into(&vec_len, sizeof(vec_len), 1)
# if prev_vec_len != 0 and vec_len != prev_vec_len:
# raise VectorReadError.mismatched_sizes(loc, line_num,
# vec_len, prev_vec_len)
# if 0 >= vec_len >= MAX_VEC_SIZE:
# raise VectorReadError.bad_size(loc, vec_len)
#
# chars = <char*>file_.alloc_read(tmp_mem, word_len, sizeof(char))
# vec = <float*>file_.alloc_read(self.mem, vec_len, sizeof(float))
#
# string_id = self.strings[chars[:word_len]]
# # Insert words into vocab to add vector.
# self.get_by_orth(self.mem, string_id)
# while string_id >= vectors.size():
# vectors.push_back(EMPTY_VEC)
# assert vec != NULL
# vectors[string_id] = vec
# line_num += 1
# cdef LexemeC* lex
# cdef size_t lex_addr
# cdef double norm = 0.0
# cdef int i
# for orth, lex_addr in self._by_orth.items():
# lex = <LexemeC*>lex_addr
# if lex.lower < vectors.size():
# lex.vector = vectors[lex.lower]
# norm = 0.0
# for i in range(vec_len):
# norm += lex.vector[i] * lex.vector[i]
# lex.l2_norm = sqrt(norm)
# else:
# lex.vector = EMPTY_VEC
# self.vectors_length = vec_len
# return vec_len
#
#
#def write_binary_vectors(in_loc, out_loc):
# cdef CFile out_file = CFile(out_loc, 'wb')
# cdef Address mem
# cdef int32_t word_len
# cdef int32_t vec_len
# cdef char* chars
# with bz2.BZ2File(in_loc, 'r') as file_:
# for line in file_:
# pieces = line.split()
# word = pieces.pop(0)
# mem = Address(len(pieces), sizeof(float))
# vec = <float*>mem.ptr
# for i, val_str in enumerate(pieces):
# vec[i] = float(val_str)
#
# word_len = len(word)
# vec_len = len(pieces)
#
# out_file.write_from(&word_len, 1, sizeof(word_len))
# out_file.write_from(&vec_len, 1, sizeof(vec_len))
#
# chars = <char*>word
# out_file.write_from(chars, len(word), sizeof(char))
# out_file.write_from(vec, vec_len, sizeof(float))
#
#
# def resize_vectors(self, int new_size):
# """
# Set vectors_length to a new size, and allocate more memory for the Lexeme
# vectors if necessary. The memory will be zeroed.
#
# Arguments:
# new_size (int): The new size of the vectors.
# """
# cdef hash_t key
# cdef size_t addr
# if new_size > self.vectors_length:
# for key, addr in self._by_hash.items():
# lex = <LexemeC*>addr
# lex.vector = <float*>self.mem.realloc(lex.vector,
# new_size * sizeof(lex.vector[0]))
# self.vectors_length = new_size
#
#
#
# def dump(self, loc=None):
# """
# Save the lexemes binary data to the given location, or
# return a byte-string with the data if loc is None.
#
# Arguments:
# loc (Path or None): The path to save to, or None.
# """
# if loc is None:
# return self.to_bytes()
# else:
# return self.to_disk(loc)
#
# def load_lexemes(self, loc):
# """
# Load the binary vocabulary data from the given location.
#
# Arguments:
# loc (Path): The path to load from.
#
# Returns:
# None
# """
# fp = CFile(loc, 'rb',
# on_open_error=lambda: IOError('LexemeCs file not found at %s' % loc))
# cdef LexemeC* lexeme = NULL
# cdef SerializedLexemeC lex_data
# cdef hash_t key
# cdef unicode py_str
# cdef attr_t orth = 0
# assert sizeof(orth) == sizeof(lexeme.orth)
# i = 0
# while True:
# try:
# fp.read_into(&orth, 1, sizeof(orth))
# except IOError:
# break
# lexeme = <LexemeC*>self.mem.alloc(sizeof(LexemeC), 1)
# # Copy data from the file into the lexeme
# fp.read_into(&lex_data.data, 1, sizeof(lex_data.data))
# Lexeme.c_from_bytes(lexeme, lex_data)
#
# lexeme.vector = EMPTY_VEC
# py_str = self.strings[lexeme.orth]
# key = hash_string(py_str)
# self._by_hash.set(key, lexeme)
# self._by_orth.set(lexeme.orth, lexeme)
# self.length += 1
# i += 1
# fp.close()