mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-24 08:14:15 +03:00
Cache features in doc2feats
This commit is contained in:
parent
39ea38c4b1
commit
711ad5edc4
11
spacy/_ml.py
11
spacy/_ml.py
|
@ -10,7 +10,9 @@ from thinc.neural._classes.resnet import Residual
|
||||||
from thinc import describe
|
from thinc import describe
|
||||||
from thinc.describe import Dimension, Synapses, Biases, Gradient
|
from thinc.describe import Dimension, Synapses, Biases, Gradient
|
||||||
from thinc.neural._classes.affine import _set_dimensions_if_needed
|
from thinc.neural._classes.affine import _set_dimensions_if_needed
|
||||||
|
|
||||||
from .attrs import ID, LOWER, PREFIX, SUFFIX, SHAPE, TAG, DEP
|
from .attrs import ID, LOWER, PREFIX, SUFFIX, SHAPE, TAG, DEP
|
||||||
|
from .tokens.doc import Doc
|
||||||
|
|
||||||
import numpy
|
import numpy
|
||||||
|
|
||||||
|
@ -167,8 +169,13 @@ def zero_init(model):
|
||||||
def doc2feats(cols=None):
|
def doc2feats(cols=None):
|
||||||
cols = [ID, LOWER, PREFIX, SUFFIX, SHAPE]
|
cols = [ID, LOWER, PREFIX, SUFFIX, SHAPE]
|
||||||
def forward(docs, drop=0.):
|
def forward(docs, drop=0.):
|
||||||
feats = [doc.to_array(cols) for doc in docs]
|
feats = []
|
||||||
feats = [model.ops.asarray(f, dtype='uint64') for f in feats]
|
for doc in docs:
|
||||||
|
if 'cached_feats' not in doc.user_data:
|
||||||
|
doc.user_data['cached_feats'] = model.ops.asarray(
|
||||||
|
doc.to_array(cols),
|
||||||
|
dtype='uint64')
|
||||||
|
feats.append(doc.user_data['cached_feats'])
|
||||||
return feats, None
|
return feats, None
|
||||||
model = layerize(forward)
|
model = layerize(forward)
|
||||||
model.cols = cols
|
model.cols = cols
|
||||||
|
|
Loading…
Reference in New Issue
Block a user