Mirror of https://github.com/explosion/spaCy.git (synced 2024-12-24 17:06:29 +03:00)

Get spaCy train command working with neural network

* Integrate models into pipeline
* Add basic serialization (maybe incorrect)
* Fix pickle on vocab

parent 3bf4a28d8d
commit 793430aa7a

@@ -81,17 +81,19 @@ class CLI(object):
train_data=("location of JSON-formatted training data", "positional", None, str),
|
||||
dev_data=("location of JSON-formatted development data (optional)", "positional", None, str),
|
||||
n_iter=("number of iterations", "option", "n", int),
|
||||
nsents=("number of sentences", "option", None, int),
|
||||
parser_L1=("L1 regularization penalty for parser", "option", "L", float),
|
||||
no_tagger=("Don't train tagger", "flag", "T", bool),
|
||||
no_parser=("Don't train parser", "flag", "P", bool),
|
||||
no_ner=("Don't train NER", "flag", "N", bool)
|
||||
)
|
||||
def train(self, lang, output_dir, train_data, dev_data=None, n_iter=15,
|
||||
parser_L1=0.0, no_tagger=False, no_parser=False, no_ner=False):
|
||||
nsents=0, parser_L1=0.0, no_tagger=False, no_parser=False, no_ner=False):
|
||||
"""
|
||||
Train a model. Expects data in spaCy's JSON format.
|
||||
"""
|
||||
cli_train(lang, output_dir, train_data, dev_data, n_iter, not no_tagger,
|
||||
nsents = nsents or None
|
||||
cli_train(lang, output_dir, train_data, dev_data, n_iter, nsents, not no_tagger,
|
||||
not no_parser, not no_ner, parser_L1)
|
||||
|
||||
@plac.annotations(
|
||||
|
|
|
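For reference, the reworked train() signature above can be exercised directly. A minimal, hypothetical call (the module path for the CLI class and all argument values are assumptions, not part of the commit); nsents=0, the default, falls through to None below, i.e. "use all sentences":

    from spacy.__main__ import CLI   # assumed location of the CLI class

    CLI().train('en', '/tmp/model', 'train.json', dev_data='dev.json',
                n_iter=15, nsents=5000, no_ner=True)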
@@ -3,9 +3,11 @@ from cymem.cymem cimport Pool
|
|||
|
||||
cdef class CFile:
|
||||
cdef FILE* fp
|
||||
cdef bint is_open
|
||||
cdef unsigned char* data
|
||||
cdef int is_open
|
||||
cdef Pool mem
|
||||
cdef int size # For compatibility with subclass
|
||||
cdef int i # For compatibility with subclass
|
||||
cdef int _capacity # For compatibility with subclass
|
||||
|
||||
cdef int read_into(self, void* dest, size_t number, size_t elem_size) except -1
|
||||
|
@@ -16,8 +18,13 @@ cdef class CFile:
|
|||
|
||||
|
||||
|
||||
cdef class StringCFile(CFile):
|
||||
cdef class StringCFile:
|
||||
cdef unsigned char* data
|
||||
cdef int is_open
|
||||
cdef Pool mem
|
||||
cdef int size # For compatibility with subclass
|
||||
cdef int i # For compatibility with subclass
|
||||
cdef int _capacity # For compatibility with subclass
|
||||
|
||||
cdef int read_into(self, void* dest, size_t number, size_t elem_size) except -1
|
||||
|
||||
|
|
|
@@ -53,31 +53,43 @@ cdef class CFile:
|
|||
|
||||
|
||||
cdef class StringCFile:
|
||||
def __init__(self, mode, bytes data=b'', on_open_error=None):
|
||||
def __init__(self, bytes data, mode, on_open_error=None):
|
||||
self.mem = Pool()
|
||||
self.is_open = 'w' in mode
|
||||
self.is_open = 1 if 'w' in mode else 0
|
||||
self._capacity = max(len(data), 8)
|
||||
self.size = len(data)
|
||||
self.i = 0
|
||||
self.data = <unsigned char*>self.mem.alloc(1, self._capacity)
|
||||
for i in range(len(data)):
|
||||
self.data[i] = data[i]
|
||||
|
||||
def __dealloc__(self):
|
||||
# Important to override this -- or
|
||||
# we try to close a non-existant file pointer!
|
||||
pass
|
||||
|
||||
def close(self):
|
||||
self.is_open = False
|
||||
|
||||
def string_data(self):
|
||||
return (self.data-self.size)[:self.size]
|
||||
cdef bytes byte_string = b'\0' * (self.size)
|
||||
bytes_ptr = <char*>byte_string
|
||||
for i in range(self.size):
|
||||
bytes_ptr[i] = self.data[i]
|
||||
print(byte_string)
|
||||
return byte_string
|
||||
|
||||
cdef int read_into(self, void* dest, size_t number, size_t elem_size) except -1:
|
||||
memcpy(dest, self.data, elem_size * number)
|
||||
self.data += elem_size * number
|
||||
if self.i+(number * elem_size) < self.size:
|
||||
memcpy(dest, &self.data[self.i], elem_size * number)
|
||||
self.i += elem_size * number
|
||||
|
||||
cdef int write_from(self, void* src, size_t elem_size, size_t number) except -1:
|
||||
write_size = number * elem_size
|
||||
if (self.size + write_size) >= self._capacity:
|
||||
self._capacity = (self.size + write_size) * 2
|
||||
self.data = <unsigned char*>self.mem.realloc(self.data, self._capacity)
|
||||
memcpy(&self.data[self.size], src, elem_size * number)
|
||||
memcpy(&self.data[self.size], src, write_size)
|
||||
self.size += write_size
|
||||
|
||||
cdef void* alloc_read(self, Pool mem, size_t number, size_t elem_size) except *:
|
||||
|
|
|
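The StringCFile above mimics the CFile API on top of an in-memory, growable byte buffer with a separate read cursor. A pure-Python sketch of the same idea (illustration only, not the Cython class):

    class StringBuffer(object):
        """Toy stand-in for StringCFile: a read cursor over a byte buffer that
        also supports appending writes. bytearray handles the capacity growth
        that the Cython code does explicitly by doubling and realloc'ing."""
        def __init__(self, data=b''):
            self.data = bytearray(data)
            self.size = len(data)
            self.i = 0                      # read cursor, like StringCFile.i

        def read_into(self, n_bytes):
            chunk = bytes(self.data[self.i:self.i + n_bytes])
            self.i += n_bytes
            return chunk

        def write_from(self, src):
            self.data += src                # memcpy at self.size in the Cython version
            self.size += len(src)

        def string_data(self):
            return bytes(self.data[:self.size])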
@@ -4,16 +4,20 @@ from __future__ import unicode_literals, division, print_function
|
|||
import json
|
||||
from collections import defaultdict
|
||||
import cytoolz
|
||||
from pathlib import Path
|
||||
import dill
|
||||
|
||||
from ..tokens.doc import Doc
|
||||
from ..scorer import Scorer
|
||||
from ..gold import GoldParse, merge_sents
|
||||
from ..gold import read_json_file as read_gold_json
|
||||
from ..util import prints
|
||||
from .. import util
|
||||
from .. import displacy
|
||||
|
||||
|
||||
def train(language, output_dir, train_data, dev_data, n_iter, tagger, parser, ner,
|
||||
parser_L1):
|
||||
def train(language, output_dir, train_data, dev_data, n_iter, n_sents,
|
||||
tagger, parser, ner, parser_L1):
|
||||
output_path = util.ensure_path(output_dir)
|
||||
train_path = util.ensure_path(train_data)
|
||||
dev_path = util.ensure_path(dev_data)
|
||||
|
@@ -39,10 +43,8 @@ def train(language, output_dir, train_data, dev_data, n_iter, tagger, parser, ne
|
|||
'n_iter': n_iter,
|
||||
'lang': language,
|
||||
'features': lang.Defaults.tagger_features}
|
||||
gold_train = list(read_gold_json(train_path))[:100]
|
||||
gold_dev = list(read_gold_json(dev_path)) if dev_path else None
|
||||
|
||||
gold_dev = gold_dev[:100]
|
||||
gold_train = list(read_gold_json(train_path, limit=n_sents))
|
||||
gold_dev = list(read_gold_json(dev_path, limit=n_sents)) if dev_path else None
|
||||
|
||||
train_model(lang, gold_train, gold_dev, output_path, n_iter)
|
||||
if gold_dev:
|
||||
|
@@ -63,34 +65,48 @@ def train_config(config):
|
|||
def train_model(Language, train_data, dev_data, output_path, n_iter, **cfg):
|
||||
print("Itn.\tDep. Loss\tUAS\tNER F.\tTag %\tToken %")
|
||||
|
||||
nlp = Language(pipeline=['token_vectors', 'tags', 'dependencies', 'entities'])
|
||||
nlp = Language(pipeline=['token_vectors', 'tags', 'dependencies'])
|
||||
|
||||
# TODO: Get spaCy using Thinc's trainer and optimizer
|
||||
with nlp.begin_training(train_data, **cfg) as (trainer, optimizer):
|
||||
for itn, epoch in enumerate(trainer.epochs(n_iter)):
|
||||
for itn, epoch in enumerate(trainer.epochs(n_iter, gold_preproc=True)):
|
||||
losses = defaultdict(float)
|
||||
for docs, golds in epoch:
|
||||
to_render = []
|
||||
for i, (docs, golds) in enumerate(epoch):
|
||||
state = nlp.update(docs, golds, drop=0., sgd=optimizer)
|
||||
losses['dep_loss'] += state.get('parser_loss', 0.0)
|
||||
to_render.insert(0, nlp(docs[-1].text))
|
||||
to_render[0].user_data['title'] = "Batch %d" % i
|
||||
with Path('/tmp/entities.html').open('w') as file_:
|
||||
html = displacy.render(to_render[:5], style='ent', page=True,
|
||||
options={'compact': True})
|
||||
file_.write(html)
|
||||
with Path('/tmp/parses.html').open('w') as file_:
|
||||
html = displacy.render(to_render[:5], style='dep', page=True,
|
||||
options={'compact': True})
|
||||
file_.write(html)
|
||||
if dev_data:
|
||||
dev_scores = trainer.evaluate(dev_data).scores
|
||||
else:
|
||||
dev_scores = defaultdict(float)
|
||||
print_progress(itn, losses, dev_scores)
|
||||
with (output_path / 'model.bin').open('wb') as file_:
|
||||
dill.dump(nlp, file_, -1)
|
||||
#nlp.to_disk(output_path, tokenizer=False)
|
||||
|
||||
|
||||
def evaluate(Language, gold_tuples, output_path):
|
||||
print("Load parser", output_path)
|
||||
nlp = Language(path=output_path)
|
||||
def evaluate(Language, gold_tuples, path):
|
||||
with (path / 'model.bin').open('rb') as file_:
|
||||
nlp = dill.load(file_)
|
||||
scorer = Scorer()
|
||||
for raw_text, sents in gold_tuples:
|
||||
sents = merge_sents(sents)
|
||||
for annot_tuples, brackets in sents:
|
||||
if raw_text is None:
|
||||
tokens = nlp.tokenizer.tokens_from_list(annot_tuples[1])
|
||||
nlp.tagger(tokens)
|
||||
nlp.parser(tokens)
|
||||
nlp.entity(tokens)
|
||||
tokens = Doc(nlp.vocab, words=annot_tuples[1])
|
||||
state = None
|
||||
for proc in nlp.pipeline:
|
||||
state = proc(tokens, state=state)
|
||||
else:
|
||||
tokens = nlp(raw_text)
|
||||
gold = GoldParse.from_annot_tuples(tokens, annot_tuples)
|
||||
|
|
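Serialization in this draft goes through dill rather than a spaCy-specific format: train_model() dumps the whole pipeline to model.bin, and evaluate() loads it back. A minimal sketch of that round trip (assumes an existing nlp object; the path is a placeholder):

    import dill
    from pathlib import Path

    output_path = Path('/tmp/model')
    with (output_path / 'model.bin').open('wb') as file_:
        dill.dump(nlp, file_, -1)           # as in train_model() above

    with (output_path / 'model.bin').open('rb') as file_:
        nlp = dill.load(file_)              # as in evaluate() above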
|
@@ -138,14 +138,16 @@ def _min_edit_path(cand_words, gold_words):
|
|||
return prev_costs[n_gold], previous_row[-1]
|
||||
|
||||
|
||||
def read_json_file(loc, docs_filter=None):
|
||||
def read_json_file(loc, docs_filter=None, make_supertags=False, limit=None):
|
||||
loc = ensure_path(loc)
|
||||
if loc.is_dir():
|
||||
for filename in loc.iterdir():
|
||||
yield from read_json_file(loc / filename)
|
||||
yield from read_json_file(loc / filename, limit=limit)
|
||||
else:
|
||||
with loc.open('r', encoding='utf8') as file_:
|
||||
docs = ujson.load(file_)
|
||||
if limit is not None:
|
||||
docs = docs[:limit]
|
||||
for doc in docs:
|
||||
if docs_filter is not None and not docs_filter(doc):
|
||||
continue
|
||||
|
@@ -169,11 +171,13 @@ def read_json_file(loc, docs_filter=None):
|
|||
if labels[-1].lower() == 'root':
|
||||
labels[-1] = 'ROOT'
|
||||
ner.append(token.get('ner', '-'))
|
||||
sents.append((
|
||||
(ids, words, tags, heads, labels, ner),
|
||||
sent.get('brackets', [])))
|
||||
if make_supertags:
|
||||
tags[-1] = '-'.join((tags[-1], labels[-1], ner[-1]))
|
||||
sents.append([
|
||||
[ids, words, tags, heads, labels, ner],
|
||||
sent.get('brackets', [])])
|
||||
if sents:
|
||||
yield (paragraph.get('raw', None), sents)
|
||||
yield [paragraph.get('raw', None), sents]
|
||||
|
||||
|
||||
def _iob_to_biluo(tags):
|
||||
|
|
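read_json_file() walks spaCy's JSON training corpus: documents contain paragraphs, each paragraph carries a raw text plus sentences, and each sentence carries token dicts and optional brackets. A minimal, hand-written example of that layout; only the keys actually accessed above ('raw', 'ner', 'brackets') are taken from the code, the remaining key names are assumptions based on the usual spaCy JSON training format:

    example = [{
        "id": 0,
        "paragraphs": [{
            "raw": "I like London.",
            "sentences": [{
                "tokens": [
                    {"id": 0, "orth": "I",      "tag": "PRP", "head": 1,  "dep": "nsubj", "ner": "O"},
                    {"id": 1, "orth": "like",   "tag": "VBP", "head": 0,  "dep": "ROOT",  "ner": "O"},
                    {"id": 2, "orth": "London", "tag": "NNP", "head": -1, "dep": "dobj",  "ner": "U-GPE"},
                    {"id": 3, "orth": ".",      "tag": ".",   "head": -2, "dep": "punct", "ner": "O"},
                ],
                "brackets": [],
            }],
        }],
    }]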
|
@@ -1,6 +1,7 @@
|
|||
# coding: utf8
|
||||
from __future__ import absolute_import, unicode_literals
|
||||
from contextlib import contextmanager
|
||||
import dill
|
||||
|
||||
from .tokenizer import Tokenizer
|
||||
from .vocab import Vocab
|
||||
|
@@ -188,10 +189,18 @@ class Language(object):
|
|||
|
||||
@contextmanager
|
||||
def begin_training(self, gold_tuples, **cfg):
|
||||
# Populate vocab
|
||||
for _, annots_brackets in gold_tuples:
|
||||
for annots, _ in annots_brackets:
|
||||
for word in annots[1]:
|
||||
_ = self.vocab[word]
|
||||
# Handle crossing dependencies
|
||||
gold_tuples = PseudoProjectivity.preprocess_training_data(gold_tuples)
|
||||
contexts = []
|
||||
for proc in self.pipeline:
|
||||
if hasattr(proc, 'begin_training'):
|
||||
context = proc.begin_training(gold_tuples, pipeline=self.pipeline)
|
||||
context = proc.begin_training(gold_tuples,
|
||||
pipeline=self.pipeline)
|
||||
contexts.append(context)
|
||||
trainer = Trainer(self, gold_tuples, **cfg)
|
||||
yield trainer, trainer.optimizer
|
||||
|
@@ -221,15 +230,72 @@ class Language(object):
|
|||
for doc, state in stream:
|
||||
yield doc
|
||||
|
||||
def to_disk(self, path):
|
||||
raise NotImplemented
|
||||
def to_disk(self, path, **exclude):
|
||||
"""Save the current state to a directory.
|
||||
|
||||
def from_disk(self, path):
|
||||
raise NotImplemented
|
||||
Args:
|
||||
path: A path to a directory, which will be created if it doesn't
|
||||
exist. Paths may be either strings or pathlib.Path-like
|
||||
objects.
|
||||
**exclude: Prevent named attributes from being saved.
|
||||
"""
|
||||
path = util.ensure_path(path)
|
||||
if not path.exists():
|
||||
path.mkdir()
|
||||
if not path.is_dir():
|
||||
raise IOError("Output path must be a directory")
|
||||
props = {}
|
||||
for name, value in self.__dict__.items():
|
||||
if name in exclude:
|
||||
continue
|
||||
if hasattr(value, 'to_disk'):
|
||||
value.to_disk(path / name)
|
||||
else:
|
||||
props[name] = value
|
||||
with (path / 'props.pickle').open('wb') as file_:
|
||||
dill.dump(props, file_)
|
||||
|
||||
def to_bytes(self, path):
|
||||
raise NotImplemented
|
||||
def from_disk(self, path, **exclude):
|
||||
"""Load the current state from a directory.
|
||||
|
||||
def from_bytes(self, path):
|
||||
raise NotImplemented
|
||||
Args:
|
||||
path: A path to a directory. Paths may be either strings or
|
||||
pathlib.Path-like objects.
|
||||
**exclude: Prevent named attributes from being saved.
|
||||
"""
|
||||
path = util.ensure_path(path)
|
||||
for name in path.iterdir():
|
||||
if name not in exclude and hasattr(self, str(name)):
|
||||
getattr(self, name).from_disk(path / name)
|
||||
with (path / 'props.pickle').open('rb') as file_:
|
||||
bytes_data = file_.read()
|
||||
self.from_bytes(bytes_data, **exclude)
|
||||
return self
|
||||
|
||||
def to_bytes(self, **exclude):
|
||||
"""Serialize the current state to a binary string.
|
||||
|
||||
Args:
|
||||
path: A path to a directory. Paths may be either strings or
|
||||
pathlib.Path-like objects.
|
||||
**exclude: Prevent named attributes from being serialized.
|
||||
"""
|
||||
props = dict(self.__dict__)
|
||||
for key in exclude:
|
||||
if key in props:
|
||||
props.pop(key)
|
||||
return dill.dumps(props, -1)
|
||||
|
||||
def from_bytes(self, bytes_data, **exclude):
|
||||
"""Load state from a binary string.
|
||||
|
||||
Args:
|
||||
bytes_data (bytes): The data to load from.
|
||||
**exclude: Prevent named attributes from being loaded.
|
||||
"""
|
||||
props = dill.loads(bytes_data)
|
||||
for key, value in props.items():
|
||||
if key not in exclude:
|
||||
setattr(self, key, value)
|
||||
return self
|
||||
|
||||
|
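The stub to_disk/from_disk/to_bytes/from_bytes methods are replaced with working (if rough) implementations: attributes that have their own to_disk() are delegated to, everything else is dill-pickled, and keyword arguments name attributes to exclude. A sketch of how the new API is meant to be used (assumes an existing nlp object and a writable placeholder path):

    data = nlp.to_bytes()                       # dill-pickle of Language.__dict__
    nlp.from_bytes(data)                        # restores attributes in place

    nlp.to_disk('/tmp/spacy-model')             # per-attribute to_disk() + props.pickle
    nlp.from_disk('/tmp/spacy-model')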
|
|
@@ -2,7 +2,7 @@ from .typedefs cimport attr_t, hash_t, flags_t, len_t, tag_t
|
|||
from .attrs cimport attr_id_t
|
||||
from .attrs cimport ID, ORTH, LOWER, NORM, SHAPE, PREFIX, SUFFIX, LENGTH, CLUSTER, LANG
|
||||
|
||||
from .structs cimport LexemeC
|
||||
from .structs cimport LexemeC, SerializedLexemeC
|
||||
from .strings cimport StringStore
|
||||
from .vocab cimport Vocab
|
||||
|
||||
|
@@ -22,7 +22,23 @@ cdef class Lexeme:
|
|||
self.c = lex
|
||||
self.vocab = vocab
|
||||
self.orth = lex.orth
|
||||
|
||||
|
||||
@staticmethod
|
||||
cdef inline SerializedLexemeC c_to_bytes(const LexemeC* lex) nogil:
|
||||
cdef SerializedLexemeC lex_data
|
||||
buff = <const unsigned char*>&lex.flags
|
||||
end = <const unsigned char*>&lex.l2_norm + sizeof(lex.l2_norm)
|
||||
for i in range(sizeof(lex_data.data)):
|
||||
lex_data.data[i] = buff[i]
|
||||
return lex_data
|
||||
|
||||
@staticmethod
|
||||
cdef inline void c_from_bytes(LexemeC* lex, SerializedLexemeC lex_data) nogil:
|
||||
buff = <unsigned char*>&lex.flags
|
||||
end = <unsigned char*>&lex.l2_norm + sizeof(lex.l2_norm)
|
||||
for i in range(sizeof(lex_data.data)):
|
||||
buff[i] = lex_data.data[i]
|
||||
|
||||
@staticmethod
|
||||
cdef inline void set_struct_attr(LexemeC* lex, attr_id_t name, attr_t value) nogil:
|
||||
if name < (sizeof(flags_t) * 8):
|
||||
|
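c_to_bytes()/c_from_bytes() copy the LexemeC struct byte-for-byte, from the address of .flags through the end of .l2_norm, into the fixed-size SerializedLexemeC buffer and back. A rough Python/ctypes analogy of that pattern (the toy struct here is an assumption for illustration, not the real LexemeC layout):

    import ctypes

    class ToyLexeme(ctypes.Structure):
        _fields_ = [('flags', ctypes.c_uint64),     # stand-ins for the real fields
                    ('orth', ctypes.c_uint32),
                    ('l2_norm', ctypes.c_float)]

    def to_bytes(lex):
        # copy sizeof(lex) bytes starting at the struct's address
        return ctypes.string_at(ctypes.addressof(lex), ctypes.sizeof(lex))

    def from_bytes(lex, data):
        # copy the bytes back over the struct in place
        ctypes.memmove(ctypes.addressof(lex), data, ctypes.sizeof(lex))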
|
|
@@ -116,6 +116,29 @@ cdef class Lexeme:
|
|||
return 0.0
|
||||
return numpy.dot(self.vector, other.vector) / (self.vector_norm * other.vector_norm)
|
||||
|
||||
def to_bytes(self):
|
||||
lex_data = Lexeme.c_to_bytes(self.c)
|
||||
start = <const char*>&self.c.flags
|
||||
end = <const char*>&self.c.l2_norm + sizeof(self.c.l2_norm)
|
||||
assert (end-start) == sizeof(lex_data.data), (end-start, sizeof(lex_data.data))
|
||||
byte_string = b'\0' * sizeof(lex_data.data)
|
||||
byte_chars = <char*>byte_string
|
||||
for i in range(sizeof(lex_data.data)):
|
||||
byte_chars[i] = lex_data.data[i]
|
||||
assert len(byte_string) == sizeof(lex_data.data), (len(byte_string),
|
||||
sizeof(lex_data.data))
|
||||
return byte_string
|
||||
|
||||
def from_bytes(self, bytes byte_string):
|
||||
# This method doesn't really have a use-case --- wrote it for testing.
|
||||
# Possibly delete? It puts the Lexeme out of synch with the vocab.
|
||||
cdef SerializedLexemeC lex_data
|
||||
assert len(byte_string) == sizeof(lex_data.data)
|
||||
for i in range(len(byte_string)):
|
||||
lex_data.data[i] = byte_string[i]
|
||||
Lexeme.c_from_bytes(self.c, lex_data)
|
||||
self.orth = self.c.orth
|
||||
|
||||
property has_vector:
|
||||
def __get__(self):
|
||||
cdef int i
|
||||
|
|
|
@@ -26,10 +26,14 @@ from .syntax.beam_parser cimport BeamParser
|
|||
from .syntax.ner cimport BiluoPushDown
|
||||
from .syntax.arc_eager cimport ArcEager
|
||||
from .tagger import Tagger
|
||||
from .syntax.stateclass cimport StateClass
|
||||
from .gold cimport GoldParse
|
||||
from .morphology cimport Morphology
|
||||
from .vocab cimport Vocab
|
||||
|
||||
from .attrs import ID, LOWER, PREFIX, SUFFIX, SHAPE, TAG, DEP
|
||||
from .attrs import ID, LOWER, PREFIX, SUFFIX, SHAPE, TAG, DEP, POS
|
||||
from ._ml import Tok2Vec, flatten, get_col, doc2feats
|
||||
from .parts_of_speech import X
|
||||
|
||||
|
||||
class TokenVectorEncoder(object):
|
||||
|
@@ -50,7 +54,7 @@ class TokenVectorEncoder(object):
|
|||
docs = [docs]
|
||||
tokvecs = self.predict(docs)
|
||||
self.set_annotations(docs, tokvecs)
|
||||
state = {} if state is not None else state
|
||||
state = {} if state is None else state
|
||||
state['tokvecs'] = tokvecs
|
||||
return state
|
||||
|
||||
|
@@ -58,7 +62,6 @@ class TokenVectorEncoder(object):
|
|||
raise NotImplementedError
|
||||
|
||||
def predict(self, docs):
|
||||
cdef Doc doc
|
||||
feats = self.doc2feats(docs)
|
||||
tokvecs = self.model(feats)
|
||||
return tokvecs
|
||||
|
@@ -68,7 +71,7 @@ class TokenVectorEncoder(object):
|
|||
for doc in docs:
|
||||
doc.tensor = tokvecs[start : start + len(doc)]
|
||||
start += len(doc)
|
||||
|
||||
|
||||
def update(self, docs, golds, state=None,
|
||||
drop=0., sgd=None):
|
||||
if isinstance(docs, Doc):
|
||||
|
@@ -88,9 +91,9 @@ class TokenVectorEncoder(object):
|
|||
|
||||
class NeuralTagger(object):
|
||||
name = 'nn_tagger'
|
||||
def __init__(self, vocab):
|
||||
def __init__(self, vocab, model=True):
|
||||
self.vocab = vocab
|
||||
self.model = Softmax(self.vocab.morphology.n_tags)
|
||||
self.model = model
|
||||
|
||||
def __call__(self, doc, state=None):
|
||||
assert state is not None
|
||||
|
@@ -132,7 +135,7 @@ class NeuralTagger(object):
|
|||
bp_tokvecs = state['bp_tokvecs']
|
||||
if self.model.nI is None:
|
||||
self.model.nI = tokvecs.shape[1]
|
||||
|
||||
|
||||
tag_scores, bp_tag_scores = self.model.begin_update(tokvecs, drop=drop)
|
||||
loss, d_tag_scores = self.get_loss(docs, golds, tag_scores)
|
||||
d_tokvecs = bp_tag_scores(d_tag_scores, sgd)
|
||||
|
@@ -141,7 +144,7 @@ class NeuralTagger(object):
|
|||
state['bp_tag_scores'] = bp_tag_scores
|
||||
state['d_tag_scores'] = d_tag_scores
|
||||
state['tag_loss'] = loss
|
||||
|
||||
|
||||
if 'd_tokvecs' in state:
|
||||
state['d_tokvecs'] += d_tokvecs
|
||||
else:
|
||||
|
@@ -161,6 +164,22 @@ class NeuralTagger(object):
|
|||
d_scores = scores - to_categorical(correct, nb_classes=scores.shape[1])
|
||||
return (d_scores**2).sum(), d_scores
|
||||
|
||||
def begin_training(self, gold_tuples, pipeline=None):
|
||||
# Populate tag map, if anything's missing.
|
||||
tag_map = dict(self.vocab.morphology.tag_map)
|
||||
for raw_text, annots_brackets in gold_tuples:
|
||||
for annots, brackets in annots_brackets:
|
||||
ids, words, tags, heads, deps, ents = annots
|
||||
for tag in tags:
|
||||
if tag not in tag_map:
|
||||
tag_map[tag] = {POS: X}
|
||||
|
||||
cdef Vocab vocab = self.vocab
|
||||
vocab.morphology = Morphology(self.vocab.strings, tag_map,
|
||||
self.vocab.morphology.lemmatizer)
|
||||
self.model = Softmax(self.vocab.morphology.n_tags)
|
||||
|
||||
|
||||
|
||||
cdef class EntityRecognizer(LinearParser):
|
||||
"""
|
||||
|
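NeuralTagger now takes model=True as a placeholder and only builds its output layer in begin_training(), once the tag map has been extended from the gold data (unknown tags default to POS X) and n_tags is known. Sketch of the intended call order (assumes existing nlp.vocab and gold_tuples objects):

    tagger = NeuralTagger(nlp.vocab)            # model=True: not built yet
    tagger.begin_training(gold_tuples)          # extends the tag map, then
                                                # sets model = Softmax(n_tags)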
@@ -209,6 +228,28 @@ cdef class NeuralEntityRecognizer(NeuralParser):
|
|||
name = 'entity'
|
||||
TransitionSystem = BiluoPushDown
|
||||
|
||||
nr_feature = 6
|
||||
|
||||
def get_token_ids(self, states):
|
||||
cdef StateClass state
|
||||
cdef int n_tokens = 6
|
||||
ids = numpy.zeros((len(states), n_tokens), dtype='i', order='c')
|
||||
for i, state in enumerate(states):
|
||||
ids[i, 0] = state.c.B(0)-1
|
||||
ids[i, 1] = state.c.B(0)
|
||||
ids[i, 2] = state.c.B(1)
|
||||
ids[i, 3] = state.c.E(0)
|
||||
ids[i, 4] = state.c.E(0)-1
|
||||
ids[i, 5] = state.c.E(0)+1
|
||||
for j in range(6):
|
||||
if ids[i, j] >= state.c.length:
|
||||
ids[i, j] = -1
|
||||
if ids[i, j] != -1:
|
||||
ids[i, j] += state.c.offset
|
||||
return ids
|
||||
|
||||
|
||||
|
||||
|
||||
cdef class BeamDependencyParser(BeamParser):
|
||||
TransitionSystem = ArcEager
|
||||
|
|
|
@@ -28,6 +28,24 @@ cdef struct LexemeC:
|
|||
float l2_norm
|
||||
|
||||
|
||||
cdef struct SerializedLexemeC:
|
||||
unsigned char[4*13 + 8] data
|
||||
# sizeof(flags_t) # flags
|
||||
# + sizeof(attr_t) # lang
|
||||
# + sizeof(attr_t) # id
|
||||
# + sizeof(attr_t) # length
|
||||
# + sizeof(attr_t) # orth
|
||||
# + sizeof(attr_t) # lower
|
||||
# + sizeof(attr_t) # norm
|
||||
# + sizeof(attr_t) # shape
|
||||
# + sizeof(attr_t) # prefix
|
||||
# + sizeof(attr_t) # suffix
|
||||
# + sizeof(attr_t) # cluster
|
||||
# + sizeof(float) # prob
|
||||
# + sizeof(float) # cluster
|
||||
# + sizeof(float) # l2_norm
|
||||
|
||||
|
||||
cdef struct Entity:
|
||||
hash_t id
|
||||
int start
|
||||
|
|
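The commented field list above accounts for the buffer size: one 8-byte flags_t plus thirteen 4-byte attr_t/float fields gives 4*13 + 8 = 60 bytes, which is the declared size of data. A one-line check:

    assert 4 * 13 + 8 == 60                     # size of SerializedLexemeC.data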
|
@@ -12,6 +12,7 @@ from libc.math cimport exp
|
|||
cimport cython
|
||||
cimport cython.parallel
|
||||
import cytoolz
|
||||
import dill
|
||||
|
||||
import numpy.random
|
||||
cimport numpy as np
|
||||
|
@@ -35,6 +36,7 @@ from thinc.api import layerize, chain
|
|||
from thinc.neural import Model, Affine, ELU, ReLu, Maxout
|
||||
from thinc.neural.ops import NumpyOps
|
||||
|
||||
from .. import util
|
||||
from ..util import get_async, get_cuda_stream
|
||||
from .._ml import zero_init, PrecomputableAffine, PrecomputableMaxouts
|
||||
from .._ml import Tok2Vec, doc2feats
|
||||
|
@@ -218,9 +220,8 @@ cdef class Parser:
|
|||
"""
|
||||
@classmethod
|
||||
def Model(cls, nr_class, token_vector_width=128, hidden_width=128, **cfg):
|
||||
nr_context_tokens = StateClass.nr_context_tokens()
|
||||
lower = PrecomputableMaxouts(hidden_width,
|
||||
nF=nr_context_tokens,
|
||||
nF=cls.nr_feature,
|
||||
nI=token_vector_width,
|
||||
pieces=cfg.get('maxout_pieces', 1))
|
||||
|
||||
|
@@ -267,7 +268,7 @@ cdef class Parser:
|
|||
self.model = model
|
||||
|
||||
def __reduce__(self):
|
||||
return (Parser, (self.vocab, self.moves, self.model, self.cfg), None, None)
|
||||
return (Parser, (self.vocab, self.moves, self.model), None, None)
|
||||
|
||||
def __call__(self, Doc tokens, state=None):
|
||||
"""
|
||||
|
@@ -392,9 +393,11 @@ cdef class Parser:
|
|||
lower, stream, drop=dropout)
|
||||
return state2vec, upper
|
||||
|
||||
nr_feature = 13
|
||||
|
||||
def get_token_ids(self, states):
|
||||
cdef StateClass state
|
||||
cdef int n_tokens = states[0].nr_context_tokens()
|
||||
cdef int n_tokens = self.nr_feature
|
||||
ids = numpy.zeros((len(states), n_tokens), dtype='i', order='c')
|
||||
for i, state in enumerate(states):
|
||||
state.set_context_tokens(ids[i])
|
||||
|
@@ -458,6 +461,22 @@ cdef class Parser:
|
|||
if self.model is True:
|
||||
self.model = self.Model(self.moves.n_moves, **cfg)
|
||||
|
||||
def to_disk(self, path):
|
||||
path = util.ensure_path(path)
|
||||
with (path / 'model.bin').open('wb') as file_:
|
||||
dill.dump(self.model, file_)
|
||||
|
||||
def from_disk(self, path):
|
||||
path = util.ensure_path(path)
|
||||
with (path / 'model.bin').open('wb') as file_:
|
||||
self.model = dill.load(file_)
|
||||
|
||||
def to_bytes(self):
|
||||
pass
|
||||
|
||||
def from_bytes(self, data):
|
||||
pass
|
||||
|
||||
|
||||
class ParserStateError(ValueError):
|
||||
def __init__(self, doc):
|
||||
|
|
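Note that Parser.from_disk() above opens model.bin in 'wb' mode before calling dill.load(); reading back presumably needs 'rb'. A minimal sketch of the intended round trip, under that assumption (helper names are illustrative, not part of the commit):

    import dill
    from spacy import util

    def parser_to_disk(parser, path):
        path = util.ensure_path(path)
        with (path / 'model.bin').open('wb') as file_:
            dill.dump(parser.model, file_)

    def parser_from_disk(parser, path):
        path = util.ensure_path(path)
        with (path / 'model.bin').open('rb') as file_:   # read mode when loading
            parser.model = dill.load(file_)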
|
@@ -19,7 +19,6 @@ def test_pickle_string_store(stringstore, text1, text2):
|
|||
assert len(stringstore) == len(unpickled)
|
||||
|
||||
|
||||
@pytest.mark.xfail
|
||||
@pytest.mark.parametrize('text1,text2', [('dog', 'cat')])
|
||||
def test_pickle_vocab(text1, text2):
|
||||
vocab = Vocab(lex_attr_getters={int(NORM): lambda string: string[:-1]})
|
||||
|
|
|
@@ -56,3 +56,18 @@ def test_vocab_lexeme_add_flag_provided_id(en_vocab):
|
|||
assert en_vocab['199'].check_flag(IS_DIGIT) == False
|
||||
assert en_vocab['the'].check_flag(is_len4) == False
|
||||
assert en_vocab['dogs'].check_flag(is_len4) == True
|
||||
|
||||
|
||||
def test_lexeme_bytes_roundtrip(en_vocab):
|
||||
one = en_vocab['one']
|
||||
alpha = en_vocab['alpha']
|
||||
assert one.orth != alpha.orth
|
||||
assert one.lower != alpha.lower
|
||||
print(one.orth, alpha.orth)
|
||||
alpha.from_bytes(one.to_bytes())
|
||||
|
||||
assert one.orth_ == alpha.orth_
|
||||
assert one.orth == alpha.orth
|
||||
assert one.lower == alpha.lower
|
||||
assert one.lower_ == alpha.lower_
|
||||
|
||||
|
|
|
@@ -97,11 +97,10 @@ cdef class Tokenizer:
|
|||
def __reduce__(self):
|
||||
args = (self.vocab,
|
||||
self._rules,
|
||||
self._prefix_re,
|
||||
self._suffix_re,
|
||||
self._infix_re,
|
||||
self.prefix_search,
|
||||
self.suffix_search,
|
||||
self.infix_finditer,
|
||||
self.token_match)
|
||||
|
||||
return (self.__class__, args, None, None)
|
||||
|
||||
cpdef Doc tokens_from_list(self, list strings):
|
||||
|
|
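With __reduce__() now returning the tokenizer's constructor arguments, the tokenizer participates in ordinary pickling. Quick round-trip sketch (assumes an existing nlp object):

    import pickle

    data = pickle.dumps(nlp.tokenizer)
    tokenizer = pickle.loads(data)
    doc = tokenizer("This is a sentence.")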
|
@@ -20,14 +20,16 @@ class Trainer(object):
|
|||
"""
|
||||
def __init__(self, nlp, gold_tuples):
|
||||
self.nlp = nlp
|
||||
self.gold_tuples = PseudoProjectivity.preprocess_training_data(gold_tuples)
|
||||
self.nr_epoch = 0
|
||||
self.optimizer = Adam(NumpyOps(), 0.001)
|
||||
self.gold_tuples = gold_tuples
|
||||
|
||||
def epochs(self, nr_epoch, augment_data=None, gold_preproc=False):
|
||||
cached_golds = {}
|
||||
def _epoch(indices):
|
||||
for i in tqdm.tqdm(indices):
|
||||
all_docs = []
|
||||
all_golds = []
|
||||
for i in indices:
|
||||
raw_text, paragraph_tuples = self.gold_tuples[i]
|
||||
if gold_preproc:
|
||||
raw_text = None
|
||||
|
@@ -43,7 +45,11 @@ class Trainer(object):
|
|||
raw_text, paragraph_tuples = augment_data(raw_text, paragraph_tuples)
|
||||
docs = self.make_docs(raw_text, paragraph_tuples)
|
||||
golds = self.make_golds(docs, paragraph_tuples)
|
||||
yield docs, golds
|
||||
all_docs.extend(docs)
|
||||
all_golds.extend(golds)
|
||||
for batch in tqdm.tqdm(partition_all(12, zip(all_docs, all_golds))):
|
||||
X, y = zip(*batch)
|
||||
yield X, y
|
||||
|
||||
indices = list(range(len(self.gold_tuples)))
|
||||
for itn in range(nr_epoch):
|
||||
|
|
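Instead of yielding one (docs, golds) pair per paragraph, the epoch generator now collects everything and re-batches it with cytoolz.partition_all, so nlp.update() sees mini-batches of at most 12 pairs. Small illustration of the batching call (the integer pairs stand in for zip(all_docs, all_golds)):

    from cytoolz import partition_all

    pairs = list(zip(range(30), range(30)))
    batches = [list(zip(*batch)) for batch in partition_all(12, pairs)]
    # -> three (X, y) batches of sizes 12, 12 and 6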
spacy/vocab.pyx
|
@@ -5,7 +5,7 @@ import bz2
|
|||
import ujson
|
||||
import re
|
||||
|
||||
from libc.string cimport memset
|
||||
from libc.string cimport memset, memcpy
|
||||
from libc.stdint cimport int32_t
|
||||
from libc.math cimport sqrt
|
||||
from cymem.cymem cimport Address
|
||||
|
@@ -13,9 +13,10 @@ from .lexeme cimport EMPTY_LEXEME
|
|||
from .lexeme cimport Lexeme
|
||||
from .strings cimport hash_string
|
||||
from .typedefs cimport attr_t
|
||||
from .cfile cimport CFile, StringCFile
|
||||
from .cfile cimport CFile
|
||||
from .tokens.token cimport Token
|
||||
from .attrs cimport PROB, LANG
|
||||
from .structs cimport SerializedLexemeC
|
||||
|
||||
from .compat import copy_reg, pickle
|
||||
from .lemmatizer import Lemmatizer
|
||||
|
@@ -42,6 +43,7 @@ cdef class Vocab:
|
|||
def load(cls, path, lex_attr_getters=None, lemmatizer=True,
|
||||
tag_map=True, oov_prob=True, **deprecated_kwargs):
|
||||
"""
|
||||
Deprecated --- replace in spaCy 2
|
||||
Load the vocabulary from a path.
|
||||
|
||||
Arguments:
|
||||
|
@@ -88,6 +90,7 @@ cdef class Vocab:
|
|||
self.load_lexemes(path / 'vocab' / 'lexemes.bin')
|
||||
return self
|
||||
|
||||
|
||||
def __init__(self, lex_attr_getters=None, tag_map=None, lemmatizer=None,
|
||||
strings=tuple(), **deprecated_kwargs):
|
||||
"""
|
||||
|
@@ -149,24 +152,7 @@ cdef class Vocab:
|
|||
The current number of lexemes stored.
|
||||
"""
|
||||
return self.length
|
||||
|
||||
def resize_vectors(self, int new_size):
|
||||
"""
|
||||
Set vectors_length to a new size, and allocate more memory for the Lexeme
|
||||
vectors if necessary. The memory will be zeroed.
|
||||
|
||||
Arguments:
|
||||
new_size (int): The new size of the vectors.
|
||||
"""
|
||||
cdef hash_t key
|
||||
cdef size_t addr
|
||||
if new_size > self.vectors_length:
|
||||
for key, addr in self._by_hash.items():
|
||||
lex = <LexemeC*>addr
|
||||
lex.vector = <float*>self.mem.realloc(lex.vector,
|
||||
new_size * sizeof(lex.vector[0]))
|
||||
self.vectors_length = new_size
|
||||
|
||||
|
||||
def add_flag(self, flag_getter, int flag_id=-1):
|
||||
"""
|
||||
Set a new boolean flag to words in the vocabulary.
|
||||
|
@@ -224,7 +210,7 @@ cdef class Vocab:
|
|||
if lex != NULL:
|
||||
if lex.orth != self.strings[string]:
|
||||
raise LookupError.mismatched_strings(
|
||||
lex.orth, self.strings[string], self.strings[lex.orth], string)
|
||||
lex.orth, self.strings[string], string)
|
||||
return lex
|
||||
else:
|
||||
return self._new_lexeme(mem, string)
|
||||
|
@@ -337,148 +323,75 @@ cdef class Vocab:
|
|||
Token.set_struct_attr(token, attr_id, value)
|
||||
return tokens
|
||||
|
||||
def dump(self, loc=None):
|
||||
"""
|
||||
Save the lexemes binary data to the given location, or
|
||||
return a byte-string with the data if loc is None.
|
||||
def to_disk(self, path):
|
||||
path = util.ensure_path(path)
|
||||
if not path.exists():
|
||||
path.mkdir()
|
||||
strings_loc = path / 'strings.json'
|
||||
with strings_loc.open('w', encoding='utf8') as file_:
|
||||
self.strings.dump(file_)
|
||||
self.dump(path / 'lexemes.bin')
|
||||
|
||||
Arguments:
|
||||
loc (Path or None): The path to save to, or None.
|
||||
"""
|
||||
cdef CFile fp
|
||||
if loc is None:
|
||||
fp = StringCFile('wb')
|
||||
else:
|
||||
fp = CFile(loc, 'wb')
|
||||
cdef size_t st
|
||||
def from_disk(self, path):
|
||||
path = util.ensure_path(path)
|
||||
with (path / 'vocab' / 'strings.json').open('r', encoding='utf8') as file_:
|
||||
strings_list = ujson.load(file_)
|
||||
for string in strings_list:
|
||||
self.strings[string]
|
||||
self.load_lexemes(path / 'lexemes.bin')
|
||||
|
||||
def lexemes_to_bytes(self, **exclude):
|
||||
cdef hash_t key
|
||||
cdef size_t addr
|
||||
cdef hash_t key
|
||||
cdef LexemeC* lexeme = NULL
|
||||
cdef SerializedLexemeC lex_data
|
||||
cdef int size = 0
|
||||
for key, addr in self._by_hash.items():
|
||||
if addr == 0:
|
||||
continue
|
||||
size += sizeof(lex_data.data)
|
||||
byte_string = b'\0' * size
|
||||
byte_ptr = <unsigned char*>byte_string
|
||||
cdef int j
|
||||
cdef int i = 0
|
||||
for key, addr in self._by_hash.items():
|
||||
if addr == 0:
|
||||
continue
|
||||
lexeme = <LexemeC*>addr
|
||||
fp.write_from(&lexeme.orth, sizeof(lexeme.orth), 1)
|
||||
fp.write_from(&lexeme.flags, sizeof(lexeme.flags), 1)
|
||||
fp.write_from(&lexeme.id, sizeof(lexeme.id), 1)
|
||||
fp.write_from(&lexeme.length, sizeof(lexeme.length), 1)
|
||||
fp.write_from(&lexeme.orth, sizeof(lexeme.orth), 1)
|
||||
fp.write_from(&lexeme.lower, sizeof(lexeme.lower), 1)
|
||||
fp.write_from(&lexeme.norm, sizeof(lexeme.norm), 1)
|
||||
fp.write_from(&lexeme.shape, sizeof(lexeme.shape), 1)
|
||||
fp.write_from(&lexeme.prefix, sizeof(lexeme.prefix), 1)
|
||||
fp.write_from(&lexeme.suffix, sizeof(lexeme.suffix), 1)
|
||||
fp.write_from(&lexeme.cluster, sizeof(lexeme.cluster), 1)
|
||||
fp.write_from(&lexeme.prob, sizeof(lexeme.prob), 1)
|
||||
fp.write_from(&lexeme.sentiment, sizeof(lexeme.sentiment), 1)
|
||||
fp.write_from(&lexeme.l2_norm, sizeof(lexeme.l2_norm), 1)
|
||||
fp.write_from(&lexeme.lang, sizeof(lexeme.lang), 1)
|
||||
fp.close()
|
||||
if loc is None:
|
||||
return fp.string_data()
|
||||
lex_data = Lexeme.c_to_bytes(lexeme)
|
||||
for j in range(sizeof(lex_data.data)):
|
||||
byte_ptr[i] = lex_data.data[j]
|
||||
i += 1
|
||||
return byte_string
|
||||
|
||||
def load_lexemes(self, loc):
|
||||
def lexemes_from_bytes(self, bytes bytes_data):
|
||||
"""
|
||||
Load the binary vocabulary data from the given location.
|
||||
|
||||
Arguments:
|
||||
loc (Path): The path to load from.
|
||||
|
||||
Returns:
|
||||
None
|
||||
Load the binary vocabulary data from the given string.
|
||||
"""
|
||||
fp = CFile(loc, 'rb',
|
||||
on_open_error=lambda: IOError('LexemeCs file not found at %s' % loc))
|
||||
cdef LexemeC* lexeme = NULL
|
||||
cdef LexemeC* lexeme
|
||||
cdef hash_t key
|
||||
cdef unicode py_str
|
||||
cdef attr_t orth = 0
|
||||
assert sizeof(orth) == sizeof(lexeme.orth)
|
||||
i = 0
|
||||
while True:
|
||||
try:
|
||||
fp.read_into(&orth, 1, sizeof(orth))
|
||||
except IOError:
|
||||
break
|
||||
lexeme = <LexemeC*>self.mem.alloc(sizeof(LexemeC), 1)
|
||||
# Copy data from the file into the lexeme
|
||||
fp.read_into(&lexeme.flags, 1, sizeof(lexeme.flags))
|
||||
fp.read_into(&lexeme.id, 1, sizeof(lexeme.id))
|
||||
fp.read_into(&lexeme.length, 1, sizeof(lexeme.length))
|
||||
fp.read_into(&lexeme.orth, 1, sizeof(lexeme.orth))
|
||||
fp.read_into(&lexeme.lower, 1, sizeof(lexeme.lower))
|
||||
fp.read_into(&lexeme.norm, 1, sizeof(lexeme.norm))
|
||||
fp.read_into(&lexeme.shape, 1, sizeof(lexeme.shape))
|
||||
fp.read_into(&lexeme.prefix, 1, sizeof(lexeme.prefix))
|
||||
fp.read_into(&lexeme.suffix, 1, sizeof(lexeme.suffix))
|
||||
fp.read_into(&lexeme.cluster, 1, sizeof(lexeme.cluster))
|
||||
fp.read_into(&lexeme.prob, 1, sizeof(lexeme.prob))
|
||||
fp.read_into(&lexeme.sentiment, 1, sizeof(lexeme.sentiment))
|
||||
fp.read_into(&lexeme.l2_norm, 1, sizeof(lexeme.l2_norm))
|
||||
fp.read_into(&lexeme.lang, 1, sizeof(lexeme.lang))
|
||||
cdef int i = 0
|
||||
cdef int j = 0
|
||||
cdef SerializedLexemeC lex_data
|
||||
chunk_size = sizeof(lex_data.data)
|
||||
cdef unsigned char* bytes_ptr = bytes_data
|
||||
for i in range(0, len(bytes_data), chunk_size):
|
||||
lexeme = <LexemeC*>self.mem.alloc(1, sizeof(LexemeC))
|
||||
for j in range(sizeof(lex_data.data)):
|
||||
lex_data.data[j] = bytes_ptr[i+j]
|
||||
Lexeme.c_from_bytes(lexeme, lex_data)
|
||||
|
||||
lexeme.vector = EMPTY_VEC
|
||||
py_str = self.strings[lexeme.orth]
|
||||
assert self.strings[py_str] == lexeme.orth, (py_str, lexeme.orth)
|
||||
key = hash_string(py_str)
|
||||
self._by_hash.set(key, lexeme)
|
||||
self._by_orth.set(lexeme.orth, lexeme)
|
||||
self.length += 1
|
||||
i += 1
|
||||
fp.close()
|
||||
|
||||
def _deserialize_lexemes(self, CFile fp):
|
||||
"""
|
||||
Load the binary vocabulary data from the given CFile.
|
||||
"""
|
||||
cdef LexemeC* lexeme = NULL
|
||||
cdef hash_t key
|
||||
cdef unicode py_str
|
||||
cdef attr_t orth = 0
|
||||
assert sizeof(orth) == sizeof(lexeme.orth)
|
||||
i = 0
|
||||
cdef int todo = fp.size
|
||||
cdef int lex_size = sizeof(lexeme.flags)
|
||||
lex_size += sizeof(lexeme.id)
|
||||
lex_size += sizeof(lexeme.length)
|
||||
lex_size += sizeof(lexeme.orth)
|
||||
lex_size += sizeof(lexeme.lower)
|
||||
lex_size += sizeof(lexeme.norm)
|
||||
lex_size += sizeof(lexeme.shape)
|
||||
lex_size += sizeof(lexeme.prefix)
|
||||
lex_size += sizeof(lexeme.suffix)
|
||||
lex_size += sizeof(lexeme.cluster)
|
||||
lex_size += sizeof(lexeme.prob)
|
||||
lex_size += sizeof(lexeme.sentiment)
|
||||
lex_size += sizeof(lexeme.l2_norm)
|
||||
lex_size += sizeof(lexeme.lang)
|
||||
while True:
|
||||
if todo < lex_size:
|
||||
break
|
||||
todo -= lex_size
|
||||
lexeme = <LexemeC*>self.mem.alloc(sizeof(LexemeC), 1)
|
||||
# Copy data from the file into the lexeme
|
||||
fp.read_into(&lexeme.flags, 1, sizeof(lexeme.flags))
|
||||
fp.read_into(&lexeme.id, 1, sizeof(lexeme.id))
|
||||
fp.read_into(&lexeme.length, 1, sizeof(lexeme.length))
|
||||
fp.read_into(&lexeme.orth, 1, sizeof(lexeme.orth))
|
||||
fp.read_into(&lexeme.lower, 1, sizeof(lexeme.lower))
|
||||
fp.read_into(&lexeme.norm, 1, sizeof(lexeme.norm))
|
||||
fp.read_into(&lexeme.shape, 1, sizeof(lexeme.shape))
|
||||
fp.read_into(&lexeme.prefix, 1, sizeof(lexeme.prefix))
|
||||
fp.read_into(&lexeme.suffix, 1, sizeof(lexeme.suffix))
|
||||
fp.read_into(&lexeme.cluster, 1, sizeof(lexeme.cluster))
|
||||
fp.read_into(&lexeme.prob, 1, sizeof(lexeme.prob))
|
||||
fp.read_into(&lexeme.sentiment, 1, sizeof(lexeme.sentiment))
|
||||
fp.read_into(&lexeme.l2_norm, 1, sizeof(lexeme.l2_norm))
|
||||
fp.read_into(&lexeme.lang, 1, sizeof(lexeme.lang))
|
||||
|
||||
lexeme.vector = EMPTY_VEC
|
||||
py_str = self.strings[lexeme.orth]
|
||||
key = hash_string(py_str)
|
||||
self._by_hash.set(key, lexeme)
|
||||
self._by_orth.set(lexeme.orth, lexeme)
|
||||
self.length += 1
|
||||
i += 1
|
||||
fp.close()
|
||||
|
||||
# Deprecated --- delete these once stable
|
||||
|
||||
def dump_vectors(self, out_loc):
|
||||
"""
|
||||
Save the word vectors to a binary file.
|
||||
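lexemes_to_bytes()/lexemes_from_bytes() above replace the old CFile-based dump/load: the lexeme table becomes a flat byte string of fixed-size SerializedLexemeC records. A sketch of the round trip (vocab and other_vocab are assumed existing Vocab objects; loading resolves each record's orth through self.strings, so both sides need the same StringStore):

    data = vocab.lexemes_to_bytes()
    assert len(data) % (4 * 13 + 8) == 0        # one 60-byte record per lexeme
    other_vocab.lexemes_from_bytes(data)        # other_vocab shares vocab's strings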
|
@@ -487,7 +400,7 @@ cdef class Vocab:
|
|||
loc (Path): The path to save to.
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
#"""
|
||||
cdef int32_t vec_len = self.vectors_length
|
||||
cdef int32_t word_len
|
||||
cdef bytes word_str
|
||||
|
@@ -508,6 +421,8 @@ cdef class Vocab:
|
|||
out_file.write_from(vec, vec_len, sizeof(float))
|
||||
out_file.close()
|
||||
|
||||
|
||||
|
||||
def load_vectors(self, file_):
|
||||
"""
|
||||
Load vectors from a text-based file.
|
||||
|
@@ -610,38 +525,22 @@ cdef class Vocab:
|
|||
return vec_len
|
||||
|
||||
|
||||
def pickle_vocab(vocab):
|
||||
sstore = vocab.strings
|
||||
morph = vocab.morphology
|
||||
length = vocab.length
|
||||
data_dir = vocab.data_dir
|
||||
lex_attr_getters = vocab.lex_attr_getters
|
||||
def resize_vectors(self, int new_size):
|
||||
"""
|
||||
Set vectors_length to a new size, and allocate more memory for the Lexeme
|
||||
vectors if necessary. The memory will be zeroed.
|
||||
|
||||
lexemes_data = vocab.dump()
|
||||
vectors_length = vocab.vectors_length
|
||||
|
||||
return (unpickle_vocab,
|
||||
(sstore, morph, data_dir, lex_attr_getters,
|
||||
lexemes_data, length, vectors_length))
|
||||
|
||||
|
||||
def unpickle_vocab(sstore, morphology, data_dir,
|
||||
lex_attr_getters, bytes lexemes_data, int length, int vectors_length):
|
||||
cdef Vocab vocab = Vocab()
|
||||
vocab.length = length
|
||||
vocab.vectors_length = vectors_length
|
||||
vocab.strings = sstore
|
||||
cdef CFile fp = StringCFile('r', data=lexemes_data)
|
||||
vocab.morphology = morphology
|
||||
vocab.data_dir = data_dir
|
||||
vocab.lex_attr_getters = lex_attr_getters
|
||||
vocab._deserialize_lexemes(fp)
|
||||
vocab.length = length
|
||||
vocab.vectors_length = vectors_length
|
||||
return vocab
|
||||
|
||||
|
||||
copy_reg.pickle(Vocab, pickle_vocab, unpickle_vocab)
|
||||
Arguments:
|
||||
new_size (int): The new size of the vectors.
|
||||
"""
|
||||
cdef hash_t key
|
||||
cdef size_t addr
|
||||
if new_size > self.vectors_length:
|
||||
for key, addr in self._by_hash.items():
|
||||
lex = <LexemeC*>addr
|
||||
lex.vector = <float*>self.mem.realloc(lex.vector,
|
||||
new_size * sizeof(lex.vector[0]))
|
||||
self.vectors_length = new_size
|
||||
|
||||
|
||||
def write_binary_vectors(in_loc, out_loc):
|
||||
|
@@ -670,6 +569,39 @@ def write_binary_vectors(in_loc, out_loc):
|
|||
out_file.write_from(vec, vec_len, sizeof(float))
|
||||
|
||||
|
||||
def pickle_vocab(vocab):
|
||||
sstore = vocab.strings
|
||||
morph = vocab.morphology
|
||||
length = vocab.length
|
||||
data_dir = vocab.data_dir
|
||||
lex_attr_getters = vocab.lex_attr_getters
|
||||
|
||||
lexemes_data = vocab.lexemes_to_bytes()
|
||||
vectors_length = vocab.vectors_length
|
||||
|
||||
return (unpickle_vocab,
|
||||
(sstore, morph, data_dir, lex_attr_getters,
|
||||
lexemes_data, length, vectors_length))
|
||||
|
||||
|
||||
def unpickle_vocab(sstore, morphology, data_dir,
|
||||
lex_attr_getters, bytes lexemes_data, int length, int vectors_length):
|
||||
cdef Vocab vocab = Vocab()
|
||||
vocab.length = length
|
||||
vocab.vectors_length = vectors_length
|
||||
vocab.strings = sstore
|
||||
vocab.morphology = morphology
|
||||
vocab.data_dir = data_dir
|
||||
vocab.lex_attr_getters = lex_attr_getters
|
||||
vocab.lexemes_from_bytes(lexemes_data)
|
||||
vocab.length = length
|
||||
vocab.vectors_length = vectors_length
|
||||
return vocab
|
||||
|
||||
|
||||
copy_reg.pickle(Vocab, pickle_vocab, unpickle_vocab)
|
||||
|
||||
|
||||
class LookupError(Exception):
|
||||
@classmethod
|
||||
def mismatched_strings(cls, id_, id_string, original_string):
|
||||
|
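pickle_vocab()/unpickle_vocab() are registered with copy_reg above, so a Vocab now round-trips through the standard pickle machinery, carrying its lexeme table as the byte string from lexemes_to_bytes(). Quick sketch (assumes an existing nlp object):

    import pickle

    data = pickle.dumps(nlp.vocab)
    vocab = pickle.loads(data)
    assert len(vocab) == len(nlp.vocab)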
@@ -701,3 +633,237 @@ class VectorReadError(Exception):
|
|||
"Vector size: %d\n"
|
||||
"Max size: %d\n"
|
||||
"Min size: 1\n" % (loc, size, MAX_VEC_SIZE))
|
||||
|
||||
|
||||
#
|
||||
#Deprecated --- delete these once stable
|
||||
#
|
||||
# def dump_vectors(self, out_loc):
|
||||
# """
|
||||
# Save the word vectors to a binary file.
|
||||
#
|
||||
# Arguments:
|
||||
# loc (Path): The path to save to.
|
||||
# Returns:
|
||||
# None
|
||||
# #"""
|
||||
# cdef int32_t vec_len = self.vectors_length
|
||||
# cdef int32_t word_len
|
||||
# cdef bytes word_str
|
||||
# cdef char* chars
|
||||
#
|
||||
# cdef Lexeme lexeme
|
||||
# cdef CFile out_file = CFile(out_loc, 'wb')
|
||||
# for lexeme in self:
|
||||
# word_str = lexeme.orth_.encode('utf8')
|
||||
# vec = lexeme.c.vector
|
||||
# word_len = len(word_str)
|
||||
#
|
||||
# out_file.write_from(&word_len, 1, sizeof(word_len))
|
||||
# out_file.write_from(&vec_len, 1, sizeof(vec_len))
|
||||
#
|
||||
# chars = <char*>word_str
|
||||
# out_file.write_from(chars, word_len, sizeof(char))
|
||||
# out_file.write_from(vec, vec_len, sizeof(float))
|
||||
# out_file.close()
|
||||
#
|
||||
#
|
||||
#
|
||||
# def load_vectors(self, file_):
|
||||
# """
|
||||
# Load vectors from a text-based file.
|
||||
#
|
||||
# Arguments:
|
||||
# file_ (buffer): The file to read from. Entries should be separated by newlines,
|
||||
# and each entry should be whitespace delimited. The first value of the entry
|
||||
# should be the word string, and subsequent entries should be the values of the
|
||||
# vector.
|
||||
#
|
||||
# Returns:
|
||||
# vec_len (int): The length of the vectors loaded.
|
||||
# """
|
||||
# cdef LexemeC* lexeme
|
||||
# cdef attr_t orth
|
||||
# cdef int32_t vec_len = -1
|
||||
# cdef double norm = 0.0
|
||||
#
|
||||
# whitespace_pattern = re.compile(r'\s', re.UNICODE)
|
||||
#
|
||||
# for line_num, line in enumerate(file_):
|
||||
# pieces = line.split()
|
||||
# word_str = " " if whitespace_pattern.match(line) else pieces.pop(0)
|
||||
# if vec_len == -1:
|
||||
# vec_len = len(pieces)
|
||||
# elif vec_len != len(pieces):
|
||||
# raise VectorReadError.mismatched_sizes(file_, line_num,
|
||||
# vec_len, len(pieces))
|
||||
# orth = self.strings[word_str]
|
||||
# lexeme = <LexemeC*><void*>self.get_by_orth(self.mem, orth)
|
||||
# lexeme.vector = <float*>self.mem.alloc(vec_len, sizeof(float))
|
||||
# for i, val_str in enumerate(pieces):
|
||||
# lexeme.vector[i] = float(val_str)
|
||||
# norm = 0.0
|
||||
# for i in range(vec_len):
|
||||
# norm += lexeme.vector[i] * lexeme.vector[i]
|
||||
# lexeme.l2_norm = sqrt(norm)
|
||||
# self.vectors_length = vec_len
|
||||
# return vec_len
|
||||
#
|
||||
# def load_vectors_from_bin_loc(self, loc):
|
||||
# """
|
||||
# Load vectors from the location of a binary file.
|
||||
#
|
||||
# Arguments:
|
||||
# loc (unicode): The path of the binary file to load from.
|
||||
#
|
||||
# Returns:
|
||||
# vec_len (int): The length of the vectors loaded.
|
||||
# """
|
||||
# cdef CFile file_ = CFile(loc, b'rb')
|
||||
# cdef int32_t word_len
|
||||
# cdef int32_t vec_len = 0
|
||||
# cdef int32_t prev_vec_len = 0
|
||||
# cdef float* vec
|
||||
# cdef Address mem
|
||||
# cdef attr_t string_id
|
||||
# cdef bytes py_word
|
||||
# cdef vector[float*] vectors
|
||||
# cdef int line_num = 0
|
||||
# cdef Pool tmp_mem = Pool()
|
||||
# while True:
|
||||
# try:
|
||||
# file_.read_into(&word_len, sizeof(word_len), 1)
|
||||
# except IOError:
|
||||
# break
|
||||
# file_.read_into(&vec_len, sizeof(vec_len), 1)
|
||||
# if prev_vec_len != 0 and vec_len != prev_vec_len:
|
||||
# raise VectorReadError.mismatched_sizes(loc, line_num,
|
||||
# vec_len, prev_vec_len)
|
||||
# if 0 >= vec_len >= MAX_VEC_SIZE:
|
||||
# raise VectorReadError.bad_size(loc, vec_len)
|
||||
#
|
||||
# chars = <char*>file_.alloc_read(tmp_mem, word_len, sizeof(char))
|
||||
# vec = <float*>file_.alloc_read(self.mem, vec_len, sizeof(float))
|
||||
#
|
||||
# string_id = self.strings[chars[:word_len]]
|
||||
# # Insert words into vocab to add vector.
|
||||
# self.get_by_orth(self.mem, string_id)
|
||||
# while string_id >= vectors.size():
|
||||
# vectors.push_back(EMPTY_VEC)
|
||||
# assert vec != NULL
|
||||
# vectors[string_id] = vec
|
||||
# line_num += 1
|
||||
# cdef LexemeC* lex
|
||||
# cdef size_t lex_addr
|
||||
# cdef double norm = 0.0
|
||||
# cdef int i
|
||||
# for orth, lex_addr in self._by_orth.items():
|
||||
# lex = <LexemeC*>lex_addr
|
||||
# if lex.lower < vectors.size():
|
||||
# lex.vector = vectors[lex.lower]
|
||||
# norm = 0.0
|
||||
# for i in range(vec_len):
|
||||
# norm += lex.vector[i] * lex.vector[i]
|
||||
# lex.l2_norm = sqrt(norm)
|
||||
# else:
|
||||
# lex.vector = EMPTY_VEC
|
||||
# self.vectors_length = vec_len
|
||||
# return vec_len
|
||||
#
|
||||
#
|
||||
#def write_binary_vectors(in_loc, out_loc):
|
||||
# cdef CFile out_file = CFile(out_loc, 'wb')
|
||||
# cdef Address mem
|
||||
# cdef int32_t word_len
|
||||
# cdef int32_t vec_len
|
||||
# cdef char* chars
|
||||
# with bz2.BZ2File(in_loc, 'r') as file_:
|
||||
# for line in file_:
|
||||
# pieces = line.split()
|
||||
# word = pieces.pop(0)
|
||||
# mem = Address(len(pieces), sizeof(float))
|
||||
# vec = <float*>mem.ptr
|
||||
# for i, val_str in enumerate(pieces):
|
||||
# vec[i] = float(val_str)
|
||||
#
|
||||
# word_len = len(word)
|
||||
# vec_len = len(pieces)
|
||||
#
|
||||
# out_file.write_from(&word_len, 1, sizeof(word_len))
|
||||
# out_file.write_from(&vec_len, 1, sizeof(vec_len))
|
||||
#
|
||||
# chars = <char*>word
|
||||
# out_file.write_from(chars, len(word), sizeof(char))
|
||||
# out_file.write_from(vec, vec_len, sizeof(float))
|
||||
#
|
||||
#
|
||||
# def resize_vectors(self, int new_size):
|
||||
# """
|
||||
# Set vectors_length to a new size, and allocate more memory for the Lexeme
|
||||
# vectors if necessary. The memory will be zeroed.
|
||||
#
|
||||
# Arguments:
|
||||
# new_size (int): The new size of the vectors.
|
||||
# """
|
||||
# cdef hash_t key
|
||||
# cdef size_t addr
|
||||
# if new_size > self.vectors_length:
|
||||
# for key, addr in self._by_hash.items():
|
||||
# lex = <LexemeC*>addr
|
||||
# lex.vector = <float*>self.mem.realloc(lex.vector,
|
||||
# new_size * sizeof(lex.vector[0]))
|
||||
# self.vectors_length = new_size
|
||||
#
|
||||
#
|
||||
|
||||
#
|
||||
# def dump(self, loc=None):
|
||||
# """
|
||||
# Save the lexemes binary data to the given location, or
|
||||
# return a byte-string with the data if loc is None.
|
||||
#
|
||||
# Arguments:
|
||||
# loc (Path or None): The path to save to, or None.
|
||||
# """
|
||||
# if loc is None:
|
||||
# return self.to_bytes()
|
||||
# else:
|
||||
# return self.to_disk(loc)
|
||||
#
|
||||
# def load_lexemes(self, loc):
|
||||
# """
|
||||
# Load the binary vocabulary data from the given location.
|
||||
#
|
||||
# Arguments:
|
||||
# loc (Path): The path to load from.
|
||||
#
|
||||
# Returns:
|
||||
# None
|
||||
# """
|
||||
# fp = CFile(loc, 'rb',
|
||||
# on_open_error=lambda: IOError('LexemeCs file not found at %s' % loc))
|
||||
# cdef LexemeC* lexeme = NULL
|
||||
# cdef SerializedLexemeC lex_data
|
||||
# cdef hash_t key
|
||||
# cdef unicode py_str
|
||||
# cdef attr_t orth = 0
|
||||
# assert sizeof(orth) == sizeof(lexeme.orth)
|
||||
# i = 0
|
||||
# while True:
|
||||
# try:
|
||||
# fp.read_into(&orth, 1, sizeof(orth))
|
||||
# except IOError:
|
||||
# break
|
||||
# lexeme = <LexemeC*>self.mem.alloc(sizeof(LexemeC), 1)
|
||||
# # Copy data from the file into the lexeme
|
||||
# fp.read_into(&lex_data.data, 1, sizeof(lex_data.data))
|
||||
# Lexeme.c_from_bytes(lexeme, lex_data)
|
||||
#
|
||||
# lexeme.vector = EMPTY_VEC
|
||||
# py_str = self.strings[lexeme.orth]
|
||||
# key = hash_string(py_str)
|
||||
# self._by_hash.set(key, lexeme)
|
||||
# self._by_orth.set(lexeme.orth, lexeme)
|
||||
# self.length += 1
|
||||
# i += 1
|
||||
# fp.close()
|
||||
|
|