Merge branch 'develop' of https://github.com/explosion/spaCy into develop

commit 3cf3fa1704

spacy/_ml.py (141 changed lines)
@@ -24,7 +24,7 @@ from thinc.linear.linear import LinearModel
 from thinc.api import uniqued, wrap, flatten_add_lengths

-from .attrs import ID, ORTH, LOWER, NORM, PREFIX, SUFFIX, SHAPE, TAG, DEP
+from .attrs import ID, ORTH, LOWER, NORM, PREFIX, SUFFIX, SHAPE, TAG, DEP, CLUSTER
 from .tokens.doc import Doc
 from . import util

@@ -469,30 +469,103 @@ def build_tagger_model(nr_class, token_vector_width, **cfg):
     return model


+@layerize
+def SpacyVectors(docs, drop=0.):
+    xp = get_array_module(docs[0].vocab.vectors.data)
+    width = docs[0].vocab.vectors.data.shape[1]
+    batch = []
+    for doc in docs:
+        indices = numpy.zeros((len(doc),), dtype='i')
+        for i, word in enumerate(doc):
+            if word.orth in doc.vocab.vectors.key2row:
+                indices[i] = doc.vocab.vectors.key2row[word.orth]
+            else:
+                indices[i] = 0
+        vectors = doc.vocab.vectors.data[indices]
+        batch.append(vectors)
+    return batch, None
+
+
+def foreach(layer, drop_factor=1.0):
+    '''Map a layer across elements in a list'''
+    def foreach_fwd(Xs, drop=0.):
+        drop *= drop_factor
+        ys = []
+        backprops = []
+        for X in Xs:
+            y, bp_y = layer.begin_update(X, drop=drop)
+            ys.append(y)
+            backprops.append(bp_y)
+        def foreach_bwd(d_ys, sgd=None):
+            d_Xs = []
+            for d_y, bp_y in zip(d_ys, backprops):
+                if bp_y is not None and bp_y is not None:
+                    d_Xs.append(d_y, sgd=sgd)
+                else:
+                    d_Xs.append(None)
+            return d_Xs
+        return ys, foreach_bwd
+    model = wrap(foreach_fwd, layer)
+    return model
+
+
 def build_text_classifier(nr_class, width=64, **cfg):
-    nr_vector = cfg.get('nr_vector', 200)
-    with Model.define_operators({'>>': chain, '+': add, '|': concatenate, '**': clone}):
-        embed_lower = HashEmbed(width, nr_vector, column=1)
-        embed_prefix = HashEmbed(width//2, nr_vector, column=2)
-        embed_suffix = HashEmbed(width//2, nr_vector, column=3)
-        embed_shape = HashEmbed(width//2, nr_vector, column=4)
+    nr_vector = cfg.get('nr_vector', 5000)
+    with Model.define_operators({'>>': chain, '+': add, '|': concatenate,
+                                 '**': clone}):
+        if cfg.get('low_data'):
+            model = (
+                SpacyVectors
+                >> flatten_add_lengths
+                >> with_getitem(0,
+                    Affine(width, 300)
+                )
+                >> ParametricAttention(width)
+                >> Pooling(sum_pool)
+                >> Residual(ReLu(width, width)) ** 2
+                >> zero_init(Affine(nr_class, width, drop_factor=0.0))
+                >> logistic
+            )
+            return model
+
+        lower = HashEmbed(width, nr_vector, column=1)
+        prefix = HashEmbed(width//2, nr_vector, column=2)
+        suffix = HashEmbed(width//2, nr_vector, column=3)
+        shape = HashEmbed(width//2, nr_vector, column=4)
+
+        trained_vectors = (
+            FeatureExtracter([ORTH, LOWER, PREFIX, SUFFIX, SHAPE, ID])
+            >> with_flatten(
+                uniqued(
+                    (lower | prefix | suffix | shape)
+                    >> LN(Maxout(width, width+(width//2)*3)),
+                    column=0
+                )
+            )
+        )
+
+        static_vectors = (
+            SpacyVectors
+            >> with_flatten(Affine(width, 300))
+        )
+
         cnn_model = (
-            FeatureExtracter([ORTH, LOWER, PREFIX, SUFFIX, SHAPE])
-            >> _flatten_add_lengths
-            >> with_getitem(0,
-                uniqued(
-                    (embed_lower | embed_prefix | embed_suffix | embed_shape)
-                    >> Maxout(width, width+(width//2)*3))
-                >> Residual(ExtractWindow(nW=1) >> ReLu(width, width*3))
-                >> Residual(ExtractWindow(nW=1) >> ReLu(width, width*3))
-                >> Residual(ExtractWindow(nW=1) >> ReLu(width, width*3))
+            # TODO Make concatenate support lists
+            concatenate_lists(trained_vectors, static_vectors)
+            >> with_flatten(
+                LN(Maxout(width, width*2))
+                >> Residual(
+                    (ExtractWindow(nW=1) >> zero_init(Maxout(width, width*3)))
+                ) ** 2, pad=2
             )
-            >> ParametricAttention(width,)
+            >> flatten_add_lengths
+            >> ParametricAttention(width)
             >> Pooling(sum_pool)
-            >> ReLu(width, width)
+            >> Residual(zero_init(Maxout(width, width)))
             >> zero_init(Affine(nr_class, width, drop_factor=0.0))
         )

         linear_model = (
             _preprocess_doc
             >> LinearModel(nr_class, drop_factor=0.)
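A minimal usage sketch of the updated classifier factory, using only the keyword arguments visible in the hunk above (the class count and width values are illustrative, not part of the commit):

    from spacy._ml import build_text_classifier

    # Full ensemble: hashed lexical features plus static pre-trained vectors.
    model = build_text_classifier(nr_class=2, width=64, nr_vector=5000)

    # 'low_data' switches to the smaller, pre-trained-vectors-only network.
    small_model = build_text_classifier(nr_class=2, width=64, low_data=True)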
@@ -507,3 +580,35 @@ def build_text_classifier(nr_class, width=64, **cfg):
     model.lsuv = False
     return model
+
+
+@layerize
+def flatten(seqs, drop=0.):
+    ops = Model.ops
+    lengths = ops.asarray([len(seq) for seq in seqs], dtype='i')
+    def finish_update(d_X, sgd=None):
+        return ops.unflatten(d_X, lengths, pad=0)
+    X = ops.flatten(seqs, pad=0)
+    return X, finish_update
+
+
+def concatenate_lists(*layers, **kwargs): # pragma: no cover
+    '''Compose two or more models `f`, `g`, etc, such that their outputs are
+    concatenated, i.e. `concatenate(f, g)(x)` computes `hstack(f(x), g(x))`
+    '''
+    if not layers:
+        return noop()
+    drop_factor = kwargs.get('drop_factor', 1.0)
+    ops = layers[0].ops
+    layers = [chain(layer, flatten) for layer in layers]
+    concat = concatenate(*layers)
+    def concatenate_lists_fwd(Xs, drop=0.):
+        drop *= drop_factor
+        lengths = ops.asarray([len(X) for X in Xs], dtype='i')
+        flat_y, bp_flat_y = concat.begin_update(Xs, drop=drop)
+        ys = ops.unflatten(flat_y, lengths)
+        def concatenate_lists_bwd(d_ys, sgd=None):
+            return bp_flat_y(ops.flatten(d_ys), sgd=sgd)
+        return ys, concatenate_lists_bwd
+    model = wrap(concatenate_lists_fwd, concat)
+    return model
+
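The new concatenate_lists helper flattens the ragged batch, lets each wrapped model produce per-token features, concatenates those features column-wise and unflattens back to one array per document. A NumPy-only sketch of that data flow, with made-up feature widths and no thinc dependency:

    import numpy

    # Pretend outputs of two sub-models for a batch of two docs (5 and 2 tokens).
    outputs_a = [numpy.ones((5, 4)), numpy.ones((2, 4))]
    outputs_b = [numpy.zeros((5, 3)), numpy.zeros((2, 3))]

    lengths = [len(x) for x in outputs_a]
    flat = numpy.hstack([numpy.vstack(outputs_a), numpy.vstack(outputs_b)])  # shape (7, 7)

    # Unflatten back into per-document arrays, mirroring ops.unflatten().
    unflat, start = [], 0
    for n in lengths:
        unflat.append(flat[start:start + n])
        start += n
    assert [x.shape for x in unflat] == [(5, 7), (2, 7)]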
spacy/about.py

@@ -3,7 +3,7 @@
 # https://github.com/pypa/warehouse/blob/master/warehouse/__about__.py

 __title__ = 'spacy-nightly'
-__version__ = '2.0.0a11'
+__version__ = '2.0.0a12'
 __summary__ = 'Industrial-strength Natural Language Processing (NLP) with Python and Cython'
 __uri__ = 'https://spacy.io'
 __author__ = 'Explosion AI'

spacy/pipeline.pyx

@@ -46,6 +46,43 @@ from ._ml import build_text_classifier, build_tagger_model
 from .parts_of_speech import X


+class SentenceSegmenter(object):
+    '''A simple spaCy hook, to allow custom sentence boundary detection logic
+    (that doesn't require the dependency parse).
+
+    To change the sentence boundary detection strategy, pass a generator
+    function `strategy` on initialization, or assign a new strategy to
+    the .strategy attribute.
+
+    Sentence detection strategies should be generators that take `Doc` objects
+    and yield `Span` objects for each sentence.
+    '''
+    name = 'sbd'
+
+    def __init__(self, vocab, strategy=None):
+        self.vocab = vocab
+        if strategy is None or strategy == 'on_punct':
+            strategy = self.split_on_punct
+        self.strategy = strategy
+
+    def __call__(self, doc):
+        doc.user_hooks['sents'] = self.strategy
+
+    @staticmethod
+    def split_on_punct(doc):
+        start = 0
+        seen_period = False
+        for i, word in enumerate(doc):
+            if seen_period and not word.is_punct:
+                yield doc[start : word.i]
+                start = word.i
+                seen_period = False
+            elif word.text in ['.', '!', '?']:
+                seen_period = True
+        if start < len(doc):
+            yield doc[start : len(doc)]
+
+
 class BaseThincComponent(object):
     name = None

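A short usage sketch of the new hook, following the behaviour defined in the hunk above (the model name and text are illustrative):

    import spacy
    from spacy.pipeline import SentenceSegmenter

    nlp = spacy.load('en')
    sbd = SentenceSegmenter(nlp.vocab)   # default strategy: split_on_punct
    doc = nlp(u'Hello world. This is a test!')
    sbd(doc)                             # installs doc.user_hooks['sents']
    print([sent.text for sent in doc.sents])

A custom strategy is just a generator that takes a Doc and yields Spans; assign it to sbd.strategy (or pass strategy=... on init) to replace the punctuation-based default.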
@@ -91,15 +128,20 @@ class BaseThincComponent(object):

     def to_bytes(self, **exclude):
         serialize = OrderedDict((
+            ('cfg', lambda: json_dumps(self.cfg)),
             ('model', lambda: self.model.to_bytes()),
             ('vocab', lambda: self.vocab.to_bytes())
         ))
         return util.to_bytes(serialize, exclude)

     def from_bytes(self, bytes_data, **exclude):
-        if self.model is True:
-            self.model = self.Model()
+        def load_model(b):
+            if self.model is True:
+                self.model = self.Model(**self.cfg)
+            self.model.from_bytes(b)
+
         deserialize = OrderedDict((
+            ('cfg', lambda b: self.cfg.update(ujson.loads(b))),
             ('model', lambda b: self.model.from_bytes(b)),
             ('vocab', lambda b: self.vocab.from_bytes(b))
         ))
@@ -108,19 +150,22 @@ class BaseThincComponent(object):

     def to_disk(self, path, **exclude):
         serialize = OrderedDict((
+            ('cfg', lambda p: p.open('w').write(json_dumps(self.cfg))),
             ('model', lambda p: p.open('wb').write(self.model.to_bytes())),
-            ('vocab', lambda p: self.vocab.to_disk(p)),
-            ('cfg', lambda p: p.open('w').write(json_dumps(self.cfg)))
+            ('vocab', lambda p: self.vocab.to_disk(p))
         ))
         util.to_disk(path, serialize, exclude)

     def from_disk(self, path, **exclude):
-        if self.model is True:
-            self.model = self.Model()
+        def load_model(p):
+            if self.model is True:
+                self.model = self.Model(**self.cfg)
+            self.model.from_bytes(p.open('rb').read())
+
         deserialize = OrderedDict((
-            ('model', lambda p: self.model.from_bytes(p.open('rb').read())),
+            ('cfg', lambda p: self.cfg.update(_load_cfg(p))),
+            ('model', load_model),
             ('vocab', lambda p: self.vocab.from_disk(p)),
-            ('cfg', lambda p: self.cfg.update(_load_cfg(p)))
         ))
         util.from_disk(path, deserialize, exclude)
         return self
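The practical effect of the two serialization hunks above: a component's cfg now travels with its weights, so a freshly constructed component can rebuild the right architecture via self.Model(**self.cfg) before loading the model data. A hedged round-trip sketch, where the component instance, constructor call and path are purely illustrative:

    # Assumes 'textcat' is a trained BaseThincComponent subclass, e.g. a TextCategorizer.
    textcat.to_disk('/tmp/textcat')      # writes 'cfg', 'model' and 'vocab'

    fresh = TextCategorizer(nlp.vocab)   # fresh.model stays True until something is loaded
    fresh.from_disk('/tmp/textcat')      # cfg is read first, then load_model() builds Model(**cfg)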
@@ -601,12 +646,13 @@ class TextCategorizer(BaseThincComponent):
         return mean_square_error, d_scores

     def begin_training(self, gold_tuples=tuple(), pipeline=None):
-        if pipeline:
+        if pipeline and getattr(pipeline[0], 'name', None) == 'tensorizer':
             token_vector_width = pipeline[0].model.nO
         else:
             token_vector_width = 64
         if self.model is True:
-            self.model = self.Model(len(self.labels), token_vector_width)
+            self.model = self.Model(len(self.labels), token_vector_width,
+                                    **self.cfg)


 cdef class EntityRecognizer(LinearParser):

spacy/util.py

@@ -170,7 +170,7 @@ def get_model_meta(path):
     meta = read_json(meta_path)
     for setting in ['lang', 'name', 'version']:
         if setting not in meta or not meta[setting]:
-            raise ValueError('No valid '%s' setting found in model meta.json' % setting)
+            raise ValueError("No valid '%s' setting found in model meta.json" % setting)
     return meta

spacy/vectors.pyx

@@ -90,6 +90,33 @@ cdef class Vectors:
     def most_similar(self, key):
         raise NotImplementedError

+    def from_glove(self, path):
+        '''Load GloVe vectors from a directory. Assumes binary format,
+        that the vocab is in a vocab.txt, and that vectors are named
+        vectors.{size}.[fd].bin, e.g. vectors.128.f.bin for 128d float32
+        vectors, vectors.300.d.bin for 300d float64 (double) vectors, etc.
+        By default GloVe outputs 64-bit vectors.'''
+        path = util.ensure_path(path)
+        for name in path.iterdir():
+            if name.parts[-1].startswith('vectors'):
+                _, dims, dtype, _2 = name.parts[-1].split('.')
+                self.width = int(dims)
+                break
+        else:
+            raise IOError("Expected file named e.g. vectors.128.f.bin")
+        bin_loc = path / 'vectors.{dims}.{dtype}.bin'.format(dims=dims,
+                                                             dtype=dtype)
+        with bin_loc.open('rb') as file_:
+            self.data = numpy.fromfile(file_, dtype='float64')
+            self.data = numpy.ascontiguousarray(self.data, dtype='float32')
+        n = 0
+        with (path / 'vocab.txt').open('r') as file_:
+            for line in file_:
+                self.add(line.strip())
+                n += 1
+        if (self.data.size % self.width) == 0:
+            self.data
+
     def to_disk(self, path, **exclude):
         serializers = OrderedDict((
             ('vectors', lambda p: numpy.save(p.open('wb'), self.data, allow_pickle=False)),
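Going by the from_glove docstring above, the expected on-disk layout and the call look roughly like this (the directory name is illustrative):

    glove_output/
        vocab.txt            # one token per line
        vectors.128.f.bin    # 128-dimensional float32 vectors ('d' suffix = float64)

    # Load the GloVe output into an existing vocab's vectors table:
    nlp.vocab.vectors.from_glove('glove_output')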