mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-27 17:54:39 +03:00
db55577c45
* Remove unicode declarations * Remove Python 3.5 and 2.7 from CI * Don't require pathlib * Replace compat helpers * Remove OrderedDict * Use f-strings * Set Cython compiler language level * Fix typo * Re-add OrderedDict for Table * Update setup.cfg * Revert CONTRIBUTING.md * Revert lookups.md * Revert top-level.md * Small adjustments and docs [ci skip]
42 lines
1.3 KiB
Python
42 lines
1.3 KiB
Python
from thinc.api import layerize, wrap, noop, chain, concatenate
|
|
from thinc.v2v import Model
|
|
|
|
|
|
def concatenate_lists(*layers, **kwargs):  # pragma: no cover
    """Combine several models so that their outputs are joined together,
    i.e. `concatenate(f, g)(x)` computes `hstack(f(x), g(x))`.

    Each layer is chained with `flatten` before being concatenated; the
    concatenated output is then unflattened using the lengths of the input
    sequences.

    layers: The models to combine. If none are given, a no-op layer is
        returned.
    drop_factor (keyword, default 1.0): Multiplier applied to the `drop`
        value handed to the forward pass (skipped when `drop` is None).
    RETURNS: The wrapped model.
    """
    if not layers:
        return layerize(noop())

    dropout_scale = kwargs.get("drop_factor", 1.0)
    # The backend ops are taken from the first layer, before chaining.
    backend = layers[0].ops
    merged = concatenate(*[chain(sublayer, flatten) for sublayer in layers])

    def concatenate_lists_fwd(Xs, drop=0.0):
        if drop is not None:
            drop *= dropout_scale
        # Remember the per-sequence lengths so the flat output can be split.
        seq_lengths = backend.asarray([len(X) for X in Xs], dtype="i")
        flat_out, backprop_flat = merged.begin_update(Xs, drop=drop)
        unflattened = backend.unflatten(flat_out, seq_lengths)

        def concatenate_lists_bwd(d_ys, sgd=None):
            # Gradients arrive per sequence; flatten them for the inner model.
            flat_grads = backend.flatten(d_ys)
            return backprop_flat(flat_grads, sgd=sgd)

        return unflattened, concatenate_lists_bwd

    return wrap(concatenate_lists_fwd, merged)
|
|
|
|
|
|
@layerize
def flatten(seqs, drop=0.0):
    """Flatten a list of sequences with the `Model.ops` backend.

    seqs: The sequences to flatten.
    drop: Accepted for interface compatibility; not used here.
    RETURNS: A tuple of the flattened data and a callback that unflattens a
        gradient using the lengths of the original sequences.
    """
    backend = Model.ops
    seq_lengths = backend.asarray(list(map(len, seqs)), dtype="i")
    flat = backend.flatten(seqs, pad=0)

    def finish_update(d_X, sgd=None):
        # Split the flat gradient back into one piece per input sequence.
        return backend.unflatten(d_X, seq_lengths, pad=0)

    return flat, finish_update
|