Merge branch 'develop' of https://github.com/explosion/spaCy into develop

commit 3cf3fa1704

 spacy/_ml.py | 141
spacy/_ml.py
@@ -24,7 +24,7 @@ from thinc.linear.linear import LinearModel
 from thinc.api import uniqued, wrap, flatten_add_lengths
 
 
-from .attrs import ID, ORTH, LOWER, NORM, PREFIX, SUFFIX, SHAPE, TAG, DEP
+from .attrs import ID, ORTH, LOWER, NORM, PREFIX, SUFFIX, SHAPE, TAG, DEP, CLUSTER
 from .tokens.doc import Doc
 from . import util
 
@@ -469,30 +469,103 @@ def build_tagger_model(nr_class, token_vector_width, **cfg):
     return model
 
 
+@layerize
+def SpacyVectors(docs, drop=0.):
+    xp = get_array_module(docs[0].vocab.vectors.data)
+    width = docs[0].vocab.vectors.data.shape[1]
+    batch = []
+    for doc in docs:
+        indices = numpy.zeros((len(doc),), dtype='i')
+        for i, word in enumerate(doc):
+            if word.orth in doc.vocab.vectors.key2row:
+                indices[i] = doc.vocab.vectors.key2row[word.orth]
+            else:
+                indices[i] = 0
+        vectors = doc.vocab.vectors.data[indices]
+        batch.append(vectors)
+    return batch, None
+
+
+def foreach(layer, drop_factor=1.0):
+    '''Map a layer across elements in a list'''
+    def foreach_fwd(Xs, drop=0.):
+        drop *= drop_factor
+        ys = []
+        backprops = []
+        for X in Xs:
+            y, bp_y = layer.begin_update(X, drop=drop)
+            ys.append(y)
+            backprops.append(bp_y)
+        def foreach_bwd(d_ys, sgd=None):
+            d_Xs = []
+            for d_y, bp_y in zip(d_ys, backprops):
+                if d_y is not None and bp_y is not None:
+                    d_Xs.append(bp_y(d_y, sgd=sgd))
+                else:
+                    d_Xs.append(None)
+            return d_Xs
+        return ys, foreach_bwd
+    model = wrap(foreach_fwd, layer)
+    return model
+
+
 def build_text_classifier(nr_class, width=64, **cfg):
-    nr_vector = cfg.get('nr_vector', 200)
-    with Model.define_operators({'>>': chain, '+': add, '|': concatenate, '**': clone}):
-        embed_lower = HashEmbed(width, nr_vector, column=1)
-        embed_prefix = HashEmbed(width//2, nr_vector, column=2)
-        embed_suffix = HashEmbed(width//2, nr_vector, column=3)
-        embed_shape = HashEmbed(width//2, nr_vector, column=4)
+    nr_vector = cfg.get('nr_vector', 5000)
+    with Model.define_operators({'>>': chain, '+': add, '|': concatenate,
+                                 '**': clone}):
+        if cfg.get('low_data'):
+            model = (
+                SpacyVectors
+                >> flatten_add_lengths
+                >> with_getitem(0,
+                    Affine(width, 300)
+                )
+                >> ParametricAttention(width)
+                >> Pooling(sum_pool)
+                >> Residual(ReLu(width, width)) ** 2
+                >> zero_init(Affine(nr_class, width, drop_factor=0.0))
+                >> logistic
+            )
+            return model
+
+
+        lower = HashEmbed(width, nr_vector, column=1)
+        prefix = HashEmbed(width//2, nr_vector, column=2)
+        suffix = HashEmbed(width//2, nr_vector, column=3)
+        shape = HashEmbed(width//2, nr_vector, column=4)
+
+        trained_vectors = (
+            FeatureExtracter([ORTH, LOWER, PREFIX, SUFFIX, SHAPE, ID])
+            >> with_flatten(
+                uniqued(
+                    (lower | prefix | suffix | shape)
+                    >> LN(Maxout(width, width+(width//2)*3)),
+                    column=0
+                )
+            )
+        )
+
+        static_vectors = (
+            SpacyVectors
+            >> with_flatten(Affine(width, 300))
+        )
+
         cnn_model = (
-            FeatureExtracter([ORTH, LOWER, PREFIX, SUFFIX, SHAPE])
-            >> _flatten_add_lengths
-            >> with_getitem(0,
-                uniqued(
-                  (embed_lower | embed_prefix | embed_suffix | embed_shape)
-                  >> Maxout(width, width+(width//2)*3))
-                >> Residual(ExtractWindow(nW=1) >> ReLu(width, width*3))
-                >> Residual(ExtractWindow(nW=1) >> ReLu(width, width*3))
-                >> Residual(ExtractWindow(nW=1) >> ReLu(width, width*3))
+            # TODO Make concatenate support lists
+            concatenate_lists(trained_vectors, static_vectors)
+            >> with_flatten(
+                LN(Maxout(width, width*2))
+                >> Residual(
+                    (ExtractWindow(nW=1) >> zero_init(Maxout(width, width*3)))
+                ) ** 2, pad=2
             )
-            >> ParametricAttention(width,)
+            >> flatten_add_lengths
+            >> ParametricAttention(width)
             >> Pooling(sum_pool)
-            >> ReLu(width, width)
+            >> Residual(zero_init(Maxout(width, width)))
             >> zero_init(Affine(nr_class, width, drop_factor=0.0))
         )
 
         linear_model = (
             _preprocess_doc
             >> LinearModel(nr_class, drop_factor=0.)
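A minimal usage sketch (not part of the commit), assuming a spacy-nightly build where `spacy._ml.build_text_classifier` is importable; it only shows which keyword arguments the `cfg` lookups above consume:

    # Sketch only -- not part of the commit. Shows how the cfg keys read
    # above (nr_vector, low_data) reach build_text_classifier via **cfg.
    from spacy._ml import build_text_classifier

    # Full ensemble: hash embeddings + CNN over concatenated vectors.
    model = build_text_classifier(nr_class=2, width=64, nr_vector=5000)

    # Low-data variant: pretrained vectors, attention and pooling only.
    small = build_text_classifier(nr_class=2, width=64, low_data=True)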
				
			
@@ -507,3 +580,35 @@ def build_text_classifier(nr_class, width=64, **cfg):
     model.lsuv = False
     return model
 
+@layerize
+def flatten(seqs, drop=0.):
+    ops = Model.ops
+    lengths = ops.asarray([len(seq) for seq in seqs], dtype='i')
+    def finish_update(d_X, sgd=None):
+        return ops.unflatten(d_X, lengths, pad=0)
+    X = ops.flatten(seqs, pad=0)
+    return X, finish_update
+
+
+def concatenate_lists(*layers, **kwargs): # pragma: no cover
+    '''Compose two or more models `f`, `g`, etc, such that their outputs are
+    concatenated, i.e. `concatenate(f, g)(x)` computes `hstack(f(x), g(x))`
+    '''
+    if not layers:
+        return noop()
+    drop_factor = kwargs.get('drop_factor', 1.0)
+    ops = layers[0].ops
+    layers = [chain(layer, flatten) for layer in layers]
+    concat = concatenate(*layers)
+    def concatenate_lists_fwd(Xs, drop=0.):
+        drop *= drop_factor
+        lengths = ops.asarray([len(X) for X in Xs], dtype='i')
+        flat_y, bp_flat_y = concat.begin_update(Xs, drop=drop)
+        ys = ops.unflatten(flat_y, lengths)
+        def concatenate_lists_bwd(d_ys, sgd=None):
+            return bp_flat_y(ops.flatten(d_ys), sgd=sgd)
+        return ys, concatenate_lists_bwd
+    model = wrap(concatenate_lists_fwd, concat)
+    return model
+
+
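To make the docstring above concrete, here is a small pure-numpy sketch (not part of the commit) of what `concatenate_lists` is meant to compute for a batch of sequences; the toy layers stand in for `trained_vectors` and `static_vectors`:

    # Sketch only -- not part of the commit. Reference semantics for
    # concatenate_lists: per sequence, hstack the outputs of each layer.
    import numpy

    def toy_layer_a(X):
        # Stand-in for e.g. trained_vectors: (n_tokens, 64) output.
        return numpy.ones((X.shape[0], 64))

    def toy_layer_b(X):
        # Stand-in for e.g. static_vectors: (n_tokens, 64) output.
        return numpy.zeros((X.shape[0], 64))

    def concatenate_lists_reference(layers, Xs):
        return [numpy.hstack([layer(X) for layer in layers]) for X in Xs]

    docs = [numpy.random.rand(5, 300), numpy.random.rand(3, 300)]
    ys = concatenate_lists_reference([toy_layer_a, toy_layer_b], docs)
    assert [y.shape for y in ys] == [(5, 128), (3, 128)]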
				
			
spacy/about.py
@@ -3,7 +3,7 @@
 # https://github.com/pypa/warehouse/blob/master/warehouse/__about__.py
 
 __title__ = 'spacy-nightly'
-__version__ = '2.0.0a11'
+__version__ = '2.0.0a12'
 __summary__ = 'Industrial-strength Natural Language Processing (NLP) with Python and Cython'
 __uri__ = 'https://spacy.io'
 __author__ = 'Explosion AI'
				
			
spacy/pipeline.pyx
@@ -46,6 +46,43 @@ from ._ml import build_text_classifier, build_tagger_model
 from .parts_of_speech import X
 
 
+class SentenceSegmenter(object):
+    '''A simple spaCy hook, to allow custom sentence boundary detection logic
+    (that doesn't require the dependency parse).
+
+    To change the sentence boundary detection strategy, pass a generator
+    function `strategy` on initialization, or assign a new strategy to
+    the .strategy attribute.
+
+    Sentence detection strategies should be generators that take `Doc` objects
+    and yield `Span` objects for each sentence.
+    '''
+    name = 'sbd'
+
+    def __init__(self, vocab, strategy=None):
+        self.vocab = vocab
+        if strategy is None or strategy == 'on_punct':
+            strategy = self.split_on_punct
+        self.strategy = strategy
+
+    def __call__(self, doc):
+        doc.user_hooks['sents'] = self.strategy
+
+    @staticmethod
+    def split_on_punct(doc):
+        start = 0
+        seen_period = False
+        for i, word in enumerate(doc):
+            if seen_period and not word.is_punct:
+                yield doc[start : word.i]
+                start = word.i
+                seen_period = False
+            elif word.text in ['.', '!', '?']:
+                seen_period = True
+        if start < len(doc):
+            yield doc[start : len(doc)]
+
+
 class BaseThincComponent(object):
     name = None
 
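A short usage sketch for the new `SentenceSegmenter` hook (not part of the commit). It assumes a spacy-nightly 2.0 install; the newline-based strategy is purely illustrative of the "generator that yields Span objects" contract described in the docstring:

    # Sketch only -- not part of the commit.
    from spacy.lang.en import English
    from spacy.pipeline import SentenceSegmenter

    nlp = English()

    def split_on_newlines(doc):
        # Custom strategy: one Span per newline-delimited chunk.
        start = 0
        for word in doc:
            if word.text == '\n':
                yield doc[start : word.i]
                start = word.i + 1
        if start < len(doc):
            yield doc[start : len(doc)]

    sbd = SentenceSegmenter(nlp.vocab, strategy=split_on_newlines)
    doc = nlp(u'This is a sentence\nAnd another one')
    sbd(doc)   # installs the strategy on doc.user_hooks['sents']
    print([sent.text for sent in doc.sents])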
				
			
@@ -91,15 +128,20 @@ class BaseThincComponent(object):

    def to_bytes(self, **exclude):
        serialize = OrderedDict((
            ('cfg', lambda: json_dumps(self.cfg)),
            ('model', lambda: self.model.to_bytes()),
            ('vocab', lambda: self.vocab.to_bytes())
        ))
        return util.to_bytes(serialize, exclude)

    def from_bytes(self, bytes_data, **exclude):
        def load_model(b):
            if self.model is True:
            self.model = self.Model()
                self.model = self.Model(**self.cfg)
            self.model.from_bytes(b)

        deserialize = OrderedDict((
            ('cfg', lambda b: self.cfg.update(ujson.loads(b))),
            ('model', lambda b: self.model.from_bytes(b)),
            ('vocab', lambda b: self.vocab.from_bytes(b))
        ))
				
			
@@ -108,19 +150,22 @@ class BaseThincComponent(object):

    def to_disk(self, path, **exclude):
        serialize = OrderedDict((
            ('cfg', lambda p: p.open('w').write(json_dumps(self.cfg))),
            ('model', lambda p: p.open('wb').write(self.model.to_bytes())),
            ('vocab', lambda p: self.vocab.to_disk(p)),
            ('cfg', lambda p: p.open('w').write(json_dumps(self.cfg)))
            ('vocab', lambda p: self.vocab.to_disk(p))
        ))
        util.to_disk(path, serialize, exclude)

    def from_disk(self, path, **exclude):
        def load_model(p):
            if self.model is True:
            self.model = self.Model()
                self.model = self.Model(**self.cfg)
            self.model.from_bytes(p.open('rb').read())

        deserialize = OrderedDict((
            ('model', lambda p: self.model.from_bytes(p.open('rb').read())),
            ('cfg', lambda p: self.cfg.update(_load_cfg(p))),
            ('model', load_model),
            ('vocab', lambda p: self.vocab.from_disk(p)),
            ('cfg', lambda p: self.cfg.update(_load_cfg(p)))
        ))
        util.from_disk(path, deserialize, exclude)
        return self
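The two hunks above follow the same pattern: `self.model` stays a placeholder (`True`) until the saved `cfg` is available, and only then is the model built via `self.Model(**self.cfg)` and its weights loaded. A generic, self-contained sketch of that ordering (not spaCy's actual classes; all names here are illustrative):

    # Sketch only -- not part of the commit. Generic "config first, then
    # model, then weights" deserialization, mirroring load_model() above.
    import json

    class LazyComponent(object):
        def __init__(self):
            self.cfg = {}
            self.model = True                     # placeholder until loaded

        def Model(self, **cfg):
            # Stand-in factory: the real code builds a thinc model from cfg.
            return {'width': cfg.get('width', 64), 'weights': None}

        def from_bytes(self, cfg_bytes, weight_bytes):
            self.cfg.update(json.loads(cfg_bytes.decode('utf8')))  # 1. cfg
            if self.model is True:
                self.model = self.Model(**self.cfg)                # 2. build
            self.model['weights'] = weight_bytes                   # 3. load
            return self

    comp = LazyComponent().from_bytes(b'{"width": 128}', b'\x00\x01')
    assert comp.model['width'] == 128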
				
			
@@ -601,12 +646,13 @@ class TextCategorizer(BaseThincComponent):
         return mean_square_error, d_scores
 
     def begin_training(self, gold_tuples=tuple(), pipeline=None):
-        if pipeline:
+        if pipeline and getattr(pipeline[0], 'name', None) == 'tensorizer':
             token_vector_width = pipeline[0].model.nO
         else:
             token_vector_width = 64
         if self.model is True:
-            self.model = self.Model(len(self.labels), token_vector_width)
+            self.model = self.Model(len(self.labels), token_vector_width,
+                                    **self.cfg)
 
 
 cdef class EntityRecognizer(LinearParser):
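The new guard only trusts `pipeline[0].model.nO` when the first component actually identifies itself as the tensorizer, falling back to a default width otherwise. A tiny standalone illustration of that logic (dummy classes, not spaCy's):

    # Sketch only -- not part of the commit.
    class DummyTensorizer(object):
        name = 'tensorizer'
        class model(object):
            nO = 128

    def token_vector_width(pipeline, default=64):
        if pipeline and getattr(pipeline[0], 'name', None) == 'tensorizer':
            return pipeline[0].model.nO
        return default

    assert token_vector_width([DummyTensorizer()]) == 128
    assert token_vector_width([]) == 64
    assert token_vector_width([object()]) == 64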
				
			
spacy/util.py
@@ -170,7 +170,7 @@ def get_model_meta(path):
     meta = read_json(meta_path)
     for setting in ['lang', 'name', 'version']:
         if setting not in meta or not meta[setting]:
-            raise ValueError('No valid '%s' setting found in model meta.json' % setting)
+            raise ValueError("No valid '%s' setting found in model meta.json" % setting)
     return meta
 
 
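The only change above is the quoting: the old message nested bare single quotes inside a single-quoted string literal. A one-line check of the corrected form (example value only):

    # Sketch only -- not part of the commit.
    setting = 'lang'
    assert ("No valid '%s' setting found in model meta.json" % setting
            == "No valid 'lang' setting found in model meta.json")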
				
			
spacy/vectors.pyx
@@ -90,6 +90,33 @@ cdef class Vectors:
     def most_similar(self, key):
         raise NotImplementedError
 
+    def from_glove(self, path):
+        '''Load GloVe vectors from a directory. Assumes binary format,
+        that the vocab is in a vocab.txt, and that vectors are named
+        vectors.{size}.[fd].bin, e.g. vectors.128.f.bin for 128d float32
+        vectors, vectors.300.d.bin for 300d float64 (double) vectors, etc.
+        By default GloVe outputs 64-bit vectors.'''
+        path = util.ensure_path(path)
+        for name in path.iterdir():
+            if name.parts[-1].startswith('vectors'):
+                _, dims, dtype, _2 = name.parts[-1].split('.')
+                self.width = int(dims)
+                break
+        else:
+            raise IOError("Expected file named e.g. vectors.128.f.bin")
+        bin_loc = path / 'vectors.{dims}.{dtype}.bin'.format(dims=dims,
+                                                             dtype=dtype)
+        with bin_loc.open('rb') as file_:
+            self.data = numpy.fromfile(file_, dtype='float64')
+            self.data = numpy.ascontiguousarray(self.data, dtype='float32')
+        n = 0
+        with (path / 'vocab.txt').open('r') as file_:
+            for line in file_:
+                self.add(line.strip())
+                n += 1
+        if (self.data.size % self.width) == 0:
+            self.data
+
     def to_disk(self, path, **exclude):
         serializers = OrderedDict((
             ('vectors', lambda p: numpy.save(p.open('wb'), self.data, allow_pickle=False)),
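A usage sketch for the new `Vectors.from_glove()` helper (not part of the commit), assuming GloVe output converted to the binary layout the docstring describes; the directory path is a placeholder:

    # Sketch only -- not part of the commit. Expects <dir>/vocab.txt plus a
    # binary file such as vectors.128.f.bin, per the docstring above.
    from spacy.lang.en import English

    nlp = English()
    nlp.vocab.vectors.from_glove('/path/to/glove_output')  # placeholder path
    print(nlp.vocab.vectors.data.shape)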