mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-11-04 01:48:04 +03:00 
			
		
		
		
	Update tensorizer component
This commit is contained in:
		
							parent
							
								
									2bf21cbe29
								
							
						
					
					
						commit
						17c63906f9
					
				| 
						 | 
					@ -11,7 +11,7 @@ import ujson
 | 
				
			||||||
import msgpack
 | 
					import msgpack
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from thinc.api import chain
 | 
					from thinc.api import chain
 | 
				
			||||||
from thinc.v2v import Affine, Softmax
 | 
					from thinc.v2v import Affine, SELU, Softmax
 | 
				
			||||||
from thinc.t2v import Pooling, max_pool, mean_pool
 | 
					from thinc.t2v import Pooling, max_pool, mean_pool
 | 
				
			||||||
from thinc.neural.util import to_categorical, copy_array
 | 
					from thinc.neural.util import to_categorical, copy_array
 | 
				
			||||||
from thinc.neural._classes.difference import Siamese, CauchySimilarity
 | 
					from thinc.neural._classes.difference import Siamese, CauchySimilarity
 | 
				
			||||||
| 
						 | 
					@ -29,7 +29,7 @@ from .compat import json_dumps
 | 
				
			||||||
from .attrs import POS
 | 
					from .attrs import POS
 | 
				
			||||||
from .parts_of_speech import X
 | 
					from .parts_of_speech import X
 | 
				
			||||||
from ._ml import Tok2Vec, build_text_classifier, build_tagger_model
 | 
					from ._ml import Tok2Vec, build_text_classifier, build_tagger_model
 | 
				
			||||||
from ._ml import link_vectors_to_models
 | 
					from ._ml import link_vectors_to_models, zero_init, flatten
 | 
				
			||||||
from . import util
 | 
					from . import util
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -216,7 +216,7 @@ class Tensorizer(Pipe):
 | 
				
			||||||
    name = 'tensorizer'
 | 
					    name = 'tensorizer'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @classmethod
 | 
					    @classmethod
 | 
				
			||||||
    def Model(cls, width=128, embed_size=4000, **cfg):
 | 
					    def Model(cls, output_size=300, input_size=384, **cfg):
 | 
				
			||||||
        """Create a new statistical model for the class.
 | 
					        """Create a new statistical model for the class.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        width (int): Output size of the model.
 | 
					        width (int): Output size of the model.
 | 
				
			||||||
| 
						 | 
					@ -224,9 +224,11 @@ class Tensorizer(Pipe):
 | 
				
			||||||
        **cfg: Config parameters.
 | 
					        **cfg: Config parameters.
 | 
				
			||||||
        RETURNS (Model): A `thinc.neural.Model` or similar instance.
 | 
					        RETURNS (Model): A `thinc.neural.Model` or similar instance.
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        width = util.env_opt('token_vector_width', width)
 | 
					        model = chain(
 | 
				
			||||||
        embed_size = util.env_opt('embed_size', embed_size)
 | 
					                    SELU(output_size, input_size),
 | 
				
			||||||
        return Tok2Vec(width, embed_size, **cfg)
 | 
					                    SELU(output_size, output_size),
 | 
				
			||||||
 | 
					                    zero_init(Affine(output_size, output_size)))
 | 
				
			||||||
 | 
					        return model
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __init__(self, vocab, model=True, **cfg):
 | 
					    def __init__(self, vocab, model=True, **cfg):
 | 
				
			||||||
        """Construct a new statistical model. Weights are not allocated on
 | 
					        """Construct a new statistical model. Weights are not allocated on
 | 
				
			||||||
| 
						 | 
					@ -244,6 +246,7 @@ class Tensorizer(Pipe):
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        self.vocab = vocab
 | 
					        self.vocab = vocab
 | 
				
			||||||
        self.model = model
 | 
					        self.model = model
 | 
				
			||||||
 | 
					        self.input_models = []
 | 
				
			||||||
        self.cfg = dict(cfg)
 | 
					        self.cfg = dict(cfg)
 | 
				
			||||||
        self.cfg['pretrained_dims'] = self.vocab.vectors.data.shape[1]
 | 
					        self.cfg['pretrained_dims'] = self.vocab.vectors.data.shape[1]
 | 
				
			||||||
        self.cfg.setdefault('cnn_maxout_pieces', 3)
 | 
					        self.cfg.setdefault('cnn_maxout_pieces', 3)
 | 
				
			||||||
| 
						 | 
					@ -269,8 +272,8 @@ class Tensorizer(Pipe):
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        for docs in cytoolz.partition_all(batch_size, stream):
 | 
					        for docs in cytoolz.partition_all(batch_size, stream):
 | 
				
			||||||
            docs = list(docs)
 | 
					            docs = list(docs)
 | 
				
			||||||
            tokvecses = self.predict(docs)
 | 
					            tensors = self.predict(docs)
 | 
				
			||||||
            self.set_annotations(docs, tokvecses)
 | 
					            self.set_annotations(docs, tensors)
 | 
				
			||||||
            yield from docs
 | 
					            yield from docs
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def predict(self, docs):
 | 
					    def predict(self, docs):
 | 
				
			||||||
| 
						 | 
					@ -279,18 +282,19 @@ class Tensorizer(Pipe):
 | 
				
			||||||
        docs (iterable): A sequence of `Doc` objects.
 | 
					        docs (iterable): A sequence of `Doc` objects.
 | 
				
			||||||
        RETURNS (object): Vector representations for each token in the docs.
 | 
					        RETURNS (object): Vector representations for each token in the docs.
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        tokvecs = self.model(docs)
 | 
					        inputs = self.model.ops.flatten([doc.tensor for doc in docs])
 | 
				
			||||||
        return tokvecs
 | 
					        outputs = self.model(inputs)
 | 
				
			||||||
 | 
					        return self.model.ops.unflatten(outputs, [len(d) for d in docs])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def set_annotations(self, docs, tokvecses):
 | 
					    def set_annotations(self, docs, tensors):
 | 
				
			||||||
        """Set the tensor attribute for a batch of documents.
 | 
					        """Set the tensor attribute for a batch of documents.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        docs (iterable): A sequence of `Doc` objects.
 | 
					        docs (iterable): A sequence of `Doc` objects.
 | 
				
			||||||
        tokvecs (object): Vector representation for each token in the docs.
 | 
					        tensors (object): Vector representation for each token in the docs.
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        for doc, tokvecs in zip(docs, tokvecses):
 | 
					        for doc, tensor in zip(docs, tensors):
 | 
				
			||||||
            assert tokvecs.shape[0] == len(doc)
 | 
					            assert tensor.shape[0] == len(doc)
 | 
				
			||||||
            doc.tensor = tokvecs
 | 
					            doc.tensor = tensor
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def update(self, docs, golds, state=None, drop=0., sgd=None, losses=None):
 | 
					    def update(self, docs, golds, state=None, drop=0., sgd=None, losses=None):
 | 
				
			||||||
        """Update the model.
 | 
					        """Update the model.
 | 
				
			||||||
| 
						 | 
					@ -303,11 +307,34 @@ class Tensorizer(Pipe):
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        if isinstance(docs, Doc):
 | 
					        if isinstance(docs, Doc):
 | 
				
			||||||
            docs = [docs]
 | 
					            docs = [docs]
 | 
				
			||||||
        tokvecs, bp_tokvecs = self.model.begin_update(docs, drop=drop)
 | 
					        inputs = []
 | 
				
			||||||
        return tokvecs, bp_tokvecs
 | 
					        bp_inputs = []
 | 
				
			||||||
 | 
					        for tok2vec in self.input_models:
 | 
				
			||||||
 | 
					            tensor, bp_tensor = tok2vec.begin_update(docs, drop=drop)
 | 
				
			||||||
 | 
					            inputs.append(tensor)
 | 
				
			||||||
 | 
					            bp_inputs.append(bp_tensor)
 | 
				
			||||||
 | 
					        inputs = self.model.ops.xp.hstack(inputs)
 | 
				
			||||||
 | 
					        scores, bp_scores = self.model.begin_update(inputs, drop=drop)
 | 
				
			||||||
 | 
					        loss, d_scores = self.get_loss(docs, golds, scores)
 | 
				
			||||||
 | 
					        d_inputs = bp_scores(d_scores, sgd=sgd)
 | 
				
			||||||
 | 
					        d_inputs = self.model.ops.xp.split(d_inputs, len(self.input_models), axis=1)
 | 
				
			||||||
 | 
					        for d_input, bp_input in zip(d_inputs, bp_inputs): 
 | 
				
			||||||
 | 
					            bp_input(d_input, sgd=sgd)
 | 
				
			||||||
 | 
					        if losses is not None:
 | 
				
			||||||
 | 
					            losses.setdefault(self.name, 0.)
 | 
				
			||||||
 | 
					            losses[self.name] += loss
 | 
				
			||||||
 | 
					        return loss
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def get_loss(self, docs, golds, scores):
 | 
					    def get_loss(self, docs, golds, prediction):
 | 
				
			||||||
        raise NotImplementedError
 | 
					        target = []
 | 
				
			||||||
 | 
					        i = 0
 | 
				
			||||||
 | 
					        for doc in docs:
 | 
				
			||||||
 | 
					            vectors = self.model.ops.xp.vstack([w.vector for w in doc])
 | 
				
			||||||
 | 
					            target.append(vectors)
 | 
				
			||||||
 | 
					        target = self.model.ops.xp.vstack(target)
 | 
				
			||||||
 | 
					        d_scores = (prediction - target) / prediction.shape[0]
 | 
				
			||||||
 | 
					        loss = (d_scores**2).sum()
 | 
				
			||||||
 | 
					        return loss, d_scores
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def begin_training(self, gold_tuples=tuple(), pipeline=None):
 | 
					    def begin_training(self, gold_tuples=tuple(), pipeline=None):
 | 
				
			||||||
        """Allocate models, pre-process training data and acquire a trainer and
 | 
					        """Allocate models, pre-process training data and acquire a trainer and
 | 
				
			||||||
| 
						 | 
					@ -316,8 +343,13 @@ class Tensorizer(Pipe):
 | 
				
			||||||
        gold_tuples (iterable): Gold-standard training data.
 | 
					        gold_tuples (iterable): Gold-standard training data.
 | 
				
			||||||
        pipeline (list): The pipeline the model is part of.
 | 
					        pipeline (list): The pipeline the model is part of.
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
 | 
					        for name, model in pipeline:
 | 
				
			||||||
 | 
					            if getattr(model, 'tok2vec', None):
 | 
				
			||||||
 | 
					                self.input_models.append(model.tok2vec)
 | 
				
			||||||
        if self.model is True:
 | 
					        if self.model is True:
 | 
				
			||||||
            self.cfg['pretrained_dims'] = self.vocab.vectors_length
 | 
					            self.cfg['input_size'] = 384
 | 
				
			||||||
 | 
					            self.cfg['output_size'] = 300
 | 
				
			||||||
 | 
					            #self.cfg['pretrained_dims'] = self.vocab.vectors_length
 | 
				
			||||||
            self.model = self.Model(**self.cfg)
 | 
					            self.model = self.Model(**self.cfg)
 | 
				
			||||||
        link_vectors_to_models(self.vocab)
 | 
					        link_vectors_to_models(self.vocab)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -337,6 +369,13 @@ class Tagger(Pipe):
 | 
				
			||||||
    def labels(self):
 | 
					    def labels(self):
 | 
				
			||||||
        return self.vocab.morphology.tag_names
 | 
					        return self.vocab.morphology.tag_names
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @property
 | 
				
			||||||
 | 
					    def tok2vec(self):
 | 
				
			||||||
 | 
					        if self.model in (None, True, False):
 | 
				
			||||||
 | 
					            return None
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            return chain(self.model.tok2vec, flatten)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __call__(self, doc):
 | 
					    def __call__(self, doc):
 | 
				
			||||||
        tags, tokvecs = self.predict([doc])
 | 
					        tags, tokvecs = self.predict([doc])
 | 
				
			||||||
        self.set_annotations([doc], tags, tensors=tokvecs)
 | 
					        self.set_annotations([doc], tags, tensors=tokvecs)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user