	Add docstrings for Pipe API
commit 8eb0b7b779

spacy/_ml.py (90 lines changed)
@@ -4,6 +4,7 @@ from thinc.neural import Model, Maxout, Softmax, Affine
from thinc.neural._classes.hash_embed import HashEmbed
from thinc.neural.ops import NumpyOps, CupyOps
from thinc.neural.util import get_array_module
import thinc.extra.load_nlp
import random
import cytoolz

@@ -31,6 +32,7 @@ from . import util
import numpy
import io

VECTORS_KEY = 'spacy_pretrained_vectors'

@layerize
def _flatten_add_lengths(seqs, pad=0, drop=0.):
@@ -225,45 +227,52 @@ def drop_layer(layer, factor=2.):
    model.predict = layer
    return model

def link_vectors_to_models(vocab):
    vectors = vocab.vectors
    ops = Model.ops
    for word in vocab:
        if word.orth in vectors.key2row:
            word.rank = vectors.key2row[word.orth]
        else:
            word.rank = 0
    data = ops.asarray(vectors.data)
    # Set an entry here, so that vectors are accessed by StaticVectors
    # (unideal, I know)
    thinc.extra.load_nlp.VECTORS[(ops.device, VECTORS_KEY)] = data

def Tok2Vec(width, embed_size, pretrained_dims=0, **kwargs):
    assert pretrained_dims is not None

def Tok2Vec(width, embed_size, **kwargs):
    pretrained_dims = kwargs.get('pretrained_dims', 0)
    cnn_maxout_pieces = kwargs.get('cnn_maxout_pieces', 3)
    cols = [ID, NORM, PREFIX, SUFFIX, SHAPE, ORTH]
    with Model.define_operators({'>>': chain, '|': concatenate, '**': clone, '+': add}):
    with Model.define_operators({'>>': chain, '|': concatenate, '**': clone, '+': add,
                                 '*': reapply}):
        norm = HashEmbed(width, embed_size, column=cols.index(NORM), name='embed_norm')
        prefix = HashEmbed(width, embed_size//2, column=cols.index(PREFIX), name='embed_prefix')
        suffix = HashEmbed(width, embed_size//2, column=cols.index(SUFFIX), name='embed_suffix')
        shape = HashEmbed(width, embed_size//2, column=cols.index(SHAPE), name='embed_shape')
        if pretrained_dims is not None and pretrained_dims >= 1:
            glove = StaticVectors(VECTORS_KEY, width, column=cols.index(ID))

            embed = uniqued(
                (glove | norm | prefix | suffix | shape)
                >> LN(Maxout(width, width*5, pieces=3)), column=5)
        else:
            embed = uniqued(
                (norm | prefix | suffix | shape)
                >> LN(Maxout(width, width*4, pieces=3)), column=5)


        trained_vectors = (
            FeatureExtracter(cols)
            >> with_flatten(
                uniqued(
                    (norm | prefix | suffix | shape)
                    >> LN(Maxout(width, width*4, pieces=3)), column=5)
            )
        )
        convolution = Residual(
            ExtractWindow(nW=1)
            >> LN(Maxout(width, width*3, pieces=cnn_maxout_pieces))
        )

        if pretrained_dims >= 1:
            embed = concatenate_lists(trained_vectors, SpacyVectors)
            tok2vec = (
                embed
                >> with_flatten(
                    Affine(width, width+pretrained_dims)
                    >> convolution ** 4,
                    pad=4)
            )
        else:
            embed = trained_vectors
            tok2vec = (
                embed
                >> with_flatten(convolution ** 4, pad=4)
            )
        tok2vec = (
            FeatureExtracter(cols)
            >> with_flatten(
                embed >> (convolution * 4), pad=4)
        )

        # Work around thinc API limitations :(. TODO: Revise in Thinc 7
        tok2vec.nO = width
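
Note: link_vectors_to_models() above publishes the vocab's vector table into
thinc's global load_nlp.VECTORS dict, keyed by (device, VECTORS_KEY), and sets
each lexeme's rank to its row; the StaticVectors(VECTORS_KEY, ...) layer then
finds the table through that same key. A minimal sketch of the lookup side,
using only the dict shown above (the helper name is illustrative, not spaCy
API):

    import thinc.extra.load_nlp
    from thinc.neural import Model

    def get_linked_vectors(key=VECTORS_KEY):
        # Fetch the table that link_vectors_to_models() registered for the
        # current backend device.
        ops = Model.ops
        return thinc.extra.load_nlp.VECTORS[(ops.device, key)]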
@@ -271,6 +280,28 @@ def Tok2Vec(width, embed_size, pretrained_dims=0, **kwargs):
    return tok2vec


def reapply(layer, n_times):
    def reapply_fwd(X, drop=0.):
        backprops = []
        for i in range(n_times):
            Y, backprop = layer.begin_update(X, drop=drop)
            X = Y
            backprops.append(backprop)
        def reapply_bwd(dY, sgd=None):
            dX = None
            for backprop in reversed(backprops):
                dY = backprop(dY, sgd=sgd)
                if dX is None:
                    dX = dY
                else:
                    dX += dY
            return dX
        return Y, reapply_bwd
    return wrap(reapply_fwd, layer)




def asarray(ops, dtype):
    def forward(X, drop=0.):
        return ops.asarray(X, dtype=dtype), None
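
Note: reapply() backs the new '*' operator bound in Tok2Vec, so
convolution * 4 feeds each application's output back in as the next input
while reusing a single set of weights; clone ('**') would instead stack four
independently parameterized copies. A toy sketch, assuming the Affine layer
imported at the top of this file:

    from thinc.neural import Affine

    # One layer applied three times; every application backprops into the
    # same shared parameters through reapply_bwd.
    shared = Affine(4, 4)
    stacked = reapply(shared, 3)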
@@ -474,8 +505,13 @@ def getitem(i):
        return X[i], None
    return layerize(getitem_fwd)

def build_tagger_model(nr_class, token_vector_width, pretrained_dims=0, **cfg):
def build_tagger_model(nr_class, **cfg):
    embed_size = util.env_opt('embed_size', 4000)
    if 'token_vector_width' in cfg:
        token_vector_width = cfg['token_vector_width']
    else:
        token_vector_width = util.env_opt('token_vector_width', 128)
    pretrained_dims = cfg.get('pretrained_dims', 0)
    with Model.define_operators({'>>': chain, '+': add}):
        # Input: (doc, tensor) tuples
        private_tok2vec = Tok2Vec(token_vector_width, embed_size,
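
Note: with the new signature, build_tagger_model() reads its sizes from **cfg
and falls back to environment options, so callers override only what they
need. A hedged example (n_tags stands for the number of tag classes):

    # token_vector_width: taken from cfg when present, else from the
    # 'token_vector_width' env opt (default 128); pretrained_dims
    # defaults to 0 when unset.
    model = build_tagger_model(n_tags, pretrained_dims=300)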

spacy/cli/train.py
@@ -30,14 +30,14 @@ from ..compat import json_dumps
    n_iter=("number of iterations", "option", "n", int),
    n_sents=("number of sentences", "option", "ns", int),
    use_gpu=("Use GPU", "option", "g", int),
    resume=("Whether to resume training", "flag", "R", bool),
    vectors=("Model to load vectors from", "option", "v"),
    no_tagger=("Don't train tagger", "flag", "T", bool),
    no_parser=("Don't train parser", "flag", "P", bool),
    no_entities=("Don't train NER", "flag", "N", bool),
    gold_preproc=("Use gold preprocessing", "flag", "G", bool),
)
def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0,
          use_gpu=-1, resume=False, no_tagger=False, no_parser=False, no_entities=False,
          use_gpu=-1, vectors=None, no_tagger=False, no_parser=False, no_entities=False,
          gold_preproc=False):
    """
    Train a model. Expects data in spaCy's JSON format.
@@ -73,25 +73,20 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0,
    corpus = GoldCorpus(train_path, dev_path, limit=n_sents)
    n_train_words = corpus.count_train()

    if not resume:
        lang_class = util.get_lang_class(lang)
        nlp = lang_class(pipeline=pipeline)
        optimizer = nlp.begin_training(lambda: corpus.train_tuples, device=use_gpu)
    else:
        print("Load resume")
        util.use_gpu(use_gpu)
        nlp = _resume_model(lang, pipeline, corpus)
        optimizer = nlp.resume_training(device=use_gpu)
        lang_class = nlp.__class__

    lang_class = util.get_lang_class(lang)
    nlp = lang_class(pipeline=pipeline)
    if vectors:
        util.load_model(vectors, vocab=nlp.vocab)
    optimizer = nlp.begin_training(lambda: corpus.train_tuples, device=use_gpu)
    nlp._optimizer = None

    print("Itn.\tLoss\tUAS\tNER P.\tNER R.\tNER F.\tTag %\tToken %")
    try:
        train_docs = corpus.train_docs(nlp, projectivize=True, noise_level=0.0,
                                       gold_preproc=gold_preproc, max_length=0)
        train_docs = list(train_docs)
        for i in range(n_iter):
            with tqdm.tqdm(total=n_train_words, leave=False) as pbar:
                train_docs = corpus.train_docs(nlp, projectivize=True, noise_level=0.0,
                                               gold_preproc=gold_preproc, max_length=0)
                losses = {}
                for batch in minibatch(train_docs, size=batch_sizes):
                    docs, golds = zip(*batch)
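
Note: the removed --resume flag gives way to --vectors, which copies another
model's vectors into the fresh pipeline's vocab before training starts.
Roughly what the option does, done by hand (the package name is illustrative;
corpus is the GoldCorpus built above):

    from spacy import util

    nlp = util.get_lang_class('en')(pipeline=['tagger', 'parser', 'ner'])
    util.load_model('en_vectors_web_lg', vocab=nlp.vocab)  # vectors only
    optimizer = nlp.begin_training(lambda: corpus.train_tuples)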
@@ -104,8 +99,8 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0,
                util.set_env_log(False)
                epoch_model_path = output_path / ('model%d' % i)
                nlp.to_disk(epoch_model_path)
                #nlp_loaded = lang_class(pipeline=pipeline)
                #nlp_loaded = nlp_loaded.from_disk(epoch_model_path)
                nlp_loaded = lang_class(pipeline=pipeline)
                nlp_loaded = nlp_loaded.from_disk(epoch_model_path)
                scorer = nlp.evaluate(
                            corpus.dev_docs(
                                nlp,
@@ -124,26 +119,6 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0,
        except:
            pass


def _resume_model(lang, pipeline, corpus):
    nlp = util.load_model(lang)
    pipes = {getattr(pipe, 'name', None) for pipe in nlp.pipeline}
    for name in pipeline:
        if name not in pipes:
            factory = nlp.Defaults.factories[name]
            for pipe in factory(nlp):
                if hasattr(pipe, 'begin_training'):
                    pipe.begin_training(corpus.train_tuples,
                                        pipeline=nlp.pipeline)
                nlp.pipeline.append(pipe)
    nlp.meta['pipeline'] = pipeline
    if nlp.vocab.vectors.data.shape[1] >= 1:
        nlp.vocab.vectors.data = Model.ops.asarray(
                                    nlp.vocab.vectors.data)

    return nlp


def _render_parses(i, to_render):
    to_render[0].user_data['title'] = "Batch %d" % i
    with Path('/tmp/entities.html').open('w') as file_:

spacy/language.py
@@ -362,7 +362,6 @@ class Language(object):
        self._optimizer.device = device
        return self._optimizer


    def begin_training(self, get_gold_tuples=None, **cfg):
        """Allocate models, pre-process training data and acquire a trainer and
        optimizer. Used as a contextmanager.

spacy/pipeline.pyx
@@ -43,6 +43,7 @@ from .compat import json_dumps
from .attrs import ID, LOWER, PREFIX, SUFFIX, SHAPE, TAG, DEP, POS
from ._ml import rebatch, Tok2Vec, flatten
from ._ml import build_text_classifier, build_tagger_model
from ._ml import link_vectors_to_models
from .parts_of_speech import X


@@ -146,6 +147,7 @@ class BaseThincComponent(object):
        If no model has been initialized yet, the model is added.'''
        if self.model is True:
            self.model = self.Model(**self.cfg)
        link_vectors_to_models(self.vocab)

    def use_params(self, params):
        '''Modify the pipe's model, to use the given parameter values.
@@ -172,8 +174,8 @@ class BaseThincComponent(object):

        deserialize = OrderedDict((
            ('cfg', lambda b: self.cfg.update(ujson.loads(b))),
            ('model', load_model),
            ('vocab', lambda b: self.vocab.from_bytes(b))
            ('model', load_model),
        ))
        util.from_bytes(bytes_data, deserialize, exclude)
        return self
@@ -182,8 +184,8 @@ class BaseThincComponent(object):
        '''Serialize the pipe to disk.'''
        serialize = OrderedDict((
            ('cfg', lambda p: p.open('w').write(json_dumps(self.cfg))),
            ('vocab', lambda p: self.vocab.to_disk(p)),
            ('model', lambda p: p.open('wb').write(self.model.to_bytes())),
            ('vocab', lambda p: self.vocab.to_disk(p))
        ))
        util.to_disk(path, serialize, exclude)

@@ -197,8 +199,8 @@ class BaseThincComponent(object):

        deserialize = OrderedDict((
            ('cfg', lambda p: self.cfg.update(_load_cfg(p))),
            ('model', load_model),
            ('vocab', lambda p: self.vocab.from_disk(p)),
            ('model', load_model),
        ))
        util.from_disk(path, deserialize, exclude)
        return self
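
Note: in the three (de)serialization maps above, the vocab entry now runs
before the model entry. load_model() may call self.Model(**self.cfg), and the
model's dimensions depend on self.vocab.vectors, so the vocab has to be
restored first. A sketch of the dependency:

    # If 'model' ran first, self.Model(**self.cfg) could be built against an
    # empty vectors table; restoring 'vocab' first fixes the shapes it reads.
    deserialize = OrderedDict((
        ('cfg', lambda p: self.cfg.update(_load_cfg(p))),
        ('vocab', lambda p: self.vocab.from_disk(p)),  # vectors restored here
        ('model', load_model),  # may build the model from self.cfg
    ))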
@@ -246,7 +248,7 @@ class TokenVectorEncoder(BaseThincComponent):
        self.model = model
        self.cfg = dict(cfg)
        self.cfg['pretrained_dims'] = self.vocab.vectors.data.shape[1]
        self.cfg.setdefault('cnn_maxout_pieces', 2)
        self.cfg.setdefault('cnn_maxout_pieces', 3)

    def __call__(self, doc):
        """Add context-sensitive vectors to a `Doc`, e.g. from a CNN or LSTM
@@ -318,7 +320,9 @@ class TokenVectorEncoder(BaseThincComponent):
        pipeline (list): The pipeline the model is part of.
        """
        if self.model is True:
            self.cfg['pretrained_dims'] = self.vocab.vectors_length
            self.model = self.Model(**self.cfg)
            link_vectors_to_models(self.vocab)


class NeuralTagger(BaseThincComponent):
@@ -328,6 +332,7 @@ class NeuralTagger(BaseThincComponent):
        self.model = model
        self.cfg = dict(cfg)
        self.cfg.setdefault('cnn_maxout_pieces', 2)
        self.cfg.setdefault('pretrained_dims', self.vocab.vectors.data.shape[1])

    def __call__(self, doc):
        tags = self.predict(([doc], [doc.tensor]))
@@ -424,15 +429,14 @@ class NeuralTagger(BaseThincComponent):
            vocab.morphology = Morphology(vocab.strings, new_tag_map,
                                          vocab.morphology.lemmatizer,
                                          exc=vocab.morphology.exc)
        token_vector_width = pipeline[0].model.nO
        if self.model is True:
            self.model = self.Model(self.vocab.morphology.n_tags, token_vector_width,
                                    pretrained_dims=self.vocab.vectors_length)
            self.cfg['pretrained_dims'] = self.vocab.vectors.data.shape[1]
            self.model = self.Model(self.vocab.morphology.n_tags, **self.cfg)
            link_vectors_to_models(self.vocab)

    @classmethod
    def Model(cls, n_tags, token_vector_width, pretrained_dims=0):
        return build_tagger_model(n_tags, token_vector_width,
                                  pretrained_dims)
    def Model(cls, n_tags, **cfg):
        return build_tagger_model(n_tags, **cfg)

    def use_params(self, params):
        with self.model.use_params(params):
@@ -453,8 +457,7 @@ class NeuralTagger(BaseThincComponent):
            if self.model is True:
                token_vector_width = util.env_opt('token_vector_width',
                        self.cfg.get('token_vector_width', 128))
                self.model = self.Model(self.vocab.morphology.n_tags, token_vector_width,
                                        pretrained_dims=self.vocab.vectors_length)
                self.model = self.Model(self.vocab.morphology.n_tags, **self.cfg)
            self.model.from_bytes(b)

        def load_tag_map(b):
@@ -488,10 +491,7 @@ class NeuralTagger(BaseThincComponent):
    def from_disk(self, path, **exclude):
        def load_model(p):
            if self.model is True:
                token_vector_width = util.env_opt('token_vector_width',
                        self.cfg.get('token_vector_width', 128))
                self.model = self.Model(self.vocab.morphology.n_tags, token_vector_width,
                                        **self.cfg)
                self.model = self.Model(self.vocab.morphology.n_tags, **self.cfg)
            self.model.from_bytes(p.open('rb').read())

        def load_tag_map(p):
@@ -519,6 +519,7 @@ class NeuralLabeller(NeuralTagger):
        self.model = model
        self.cfg = dict(cfg)
        self.cfg.setdefault('cnn_maxout_pieces', 2)
        self.cfg.setdefault('pretrained_dims', self.vocab.vectors.data.shape[1])

    @property
    def labels(self):
@@ -541,13 +542,13 @@ class NeuralLabeller(NeuralTagger):
                        self.labels[dep] = len(self.labels)
        token_vector_width = pipeline[0].model.nO
        if self.model is True:
            self.model = self.Model(len(self.labels), token_vector_width,
                                    pretrained_dims=self.vocab.vectors_length)
            self.cfg['pretrained_dims'] = self.vocab.vectors.data.shape[1]
            self.model = self.Model(len(self.labels), **self.cfg)
            link_vectors_to_models(self.vocab)

    @classmethod
    def Model(cls, n_tags, token_vector_width, pretrained_dims=0):
        return build_tagger_model(n_tags, token_vector_width,
                                  pretrained_dims)
    def Model(cls, n_tags, **cfg):
        return build_tagger_model(n_tags, **cfg)

    def get_loss(self, docs, golds, scores):
        scores = self.model.ops.flatten(scores)
@@ -623,6 +624,7 @@ class SimilarityHook(BaseThincComponent):
        """
        if self.model is True:
            self.model = self.Model(pipeline[0].model.nO)
            link_vectors_to_models(self.vocab)


class TextCategorizer(BaseThincComponent):
@@ -696,6 +698,7 @@ class TextCategorizer(BaseThincComponent):
            self.cfg['pretrained_dims'] = self.vocab.vectors_length
            self.model = self.Model(len(self.labels), token_vector_width,
                                    **self.cfg)
            link_vectors_to_models(self.vocab)


cdef class EntityRecognizer(LinearParser):
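
Note: every component's begin_training() now ends by calling
link_vectors_to_models(self.vocab), so each freshly built model sees the
shared vectors table. A custom component would follow the same pattern; a
sketch against the BaseThincComponent API above (the signature is assumed
from the surrounding code, not quoted from it):

    class MyComponent(BaseThincComponent):
        def begin_training(self, gold_tuples=tuple(), pipeline=None):
            # Build the model lazily, then re-register the vocab's vectors
            # so StaticVectors layers can find them under VECTORS_KEY.
            if self.model is True:
                self.cfg['pretrained_dims'] = self.vocab.vectors_length
                self.model = self.Model(**self.cfg)
            link_vectors_to_models(self.vocab)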

spacy/syntax/nn_parser.pyx
@@ -49,6 +49,7 @@ from ..util import get_async, get_cuda_stream
from .._ml import zero_init, PrecomputableAffine, PrecomputableMaxouts
from .._ml import Tok2Vec, doc2feats, rebatch, fine_tune
from .._ml import Residual, drop_layer
from .._ml import link_vectors_to_models
from ..compat import json_dumps

from . import _parse_features
@@ -309,7 +310,7 @@ cdef class Parser:
            cfg['beam_density'] = util.env_opt('beam_density', 0.0)
        if 'pretrained_dims' not in cfg:
            cfg['pretrained_dims'] = self.vocab.vectors.data.shape[1]
        cfg.setdefault('cnn_maxout_pieces', 2)
        cfg.setdefault('cnn_maxout_pieces', 3)
        self.cfg = cfg
        if 'actions' in self.cfg:
            for action, labels in self.cfg.get('actions', {}).items():
@@ -791,6 +792,7 @@ cdef class Parser:
        if self.model is True:
            cfg['pretrained_dims'] = self.vocab.vectors_length
            self.model, cfg = self.Model(self.moves.n_moves, **cfg)
            link_vectors_to_models(self.vocab)
            self.cfg.update(cfg)

    def preprocess_gold(self, docs_golds):
@@ -872,8 +874,7 @@ cdef class Parser:
        msg = util.from_bytes(bytes_data, deserializers, exclude)
        if 'model' not in exclude:
            if self.model is True:
                self.model, cfg = self.Model(self.moves.n_moves,
                                    pretrained_dims=self.vocab.vectors_length)
                self.model, cfg = self.Model(**self.cfg)
                cfg['pretrained_dims'] = self.vocab.vectors_length
            else:
                cfg = {}
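
Note: Parser.from_bytes() now rebuilds the model from the stored self.cfg
rather than from hard-coded dimensions, so a parser round-trips with the
settings it was trained with. A sketch (parser and fresh_parser stand for
Parser instances sharing the same vocab):

    data = parser.to_bytes()
    fresh_parser.from_bytes(data)  # rebuilds via self.Model(**self.cfg)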

spacy/vocab.pyx
@@ -27,6 +27,7 @@ from .vectors import Vectors
from . import util
from . import attrs
from . import symbols
from ._ml import link_vectors_to_models


cdef class Vocab:
@@ -323,6 +324,7 @@ cdef class Vocab:
            self.lexemes_from_bytes(file_.read())
        if self.vectors is not None:
            self.vectors.from_disk(path, exclude='strings.json')
        link_vectors_to_models(self)
        return self

    def to_bytes(self, **exclude):
@@ -362,6 +364,7 @@ cdef class Vocab:
            ('vectors', lambda b: serialize_vectors(b))
        ))
        util.from_bytes(bytes_data, setters, exclude)
        link_vectors_to_models(self)
        return self

    def lexemes_to_bytes(self):
@@ -436,6 +439,7 @@ def unpickle_vocab(sstore, morphology, data_dir,
    vocab.lex_attr_getters = lex_attr_getters
    vocab.lexemes_from_bytes(lexemes_data)
    vocab.length = length
    link_vectors_to_models(vocab)
    return vocab
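
Note: Vocab.from_disk(), Vocab.from_bytes() and unpickle_vocab() all finish
with link_vectors_to_models(), so a vocab that comes back from storage
immediately re-registers its vectors for any models created afterwards. A
sketch (the path is illustrative):

    from spacy.vocab import Vocab

    vocab = Vocab().from_disk('/tmp/my_vocab')
    # link_vectors_to_models(vocab) has already run inside from_disk()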