Update text classification model

commit 523b0df2c9
parent 7c7fac9337

spacy/_ml.py | 25
spacy/_ml.py

@@ -4,18 +4,22 @@ from thinc.neural import Model, Maxout, Softmax, Affine
 from thinc.neural._classes.hash_embed import HashEmbed
 from thinc.neural.ops import NumpyOps, CupyOps
 from thinc.neural.util import get_array_module
+import random

 from thinc.neural._classes.convolution import ExtractWindow
 from thinc.neural._classes.static_vectors import StaticVectors
 from thinc.neural._classes.batchnorm import BatchNorm
 from thinc.neural._classes.resnet import Residual
 from thinc.neural import ReLu
+from thinc.neural._classes.selu import SELU
 from thinc import describe
 from thinc.describe import Dimension, Synapses, Biases, Gradient
 from thinc.neural._classes.affine import _set_dimensions_if_needed
 from thinc.api import FeatureExtracter, with_getitem
-from thinc.neural.pooling import Pooling, max_pool, mean_pool
+from thinc.neural.pooling import Pooling, max_pool, mean_pool, sum_pool
+from thinc.neural._classes.attention import ParametricAttention
 from thinc.linear.linear import LinearModel
+from thinc.api import uniqued, wrap

 from .attrs import ID, ORTH, LOWER, NORM, PREFIX, SUFFIX, SHAPE, TAG, DEP
 from .tokens.doc import Doc
@@ -367,7 +371,7 @@ def preprocess_doc(docs, drop=0.):


 def build_text_classifier(nr_class, width=64, **cfg):
-    nr_vector = cfg.get('nr_vector', 1000)
+    nr_vector = cfg.get('nr_vector', 200)
     with Model.define_operators({'>>': chain, '+': add, '|': concatenate, '**': clone}):
         embed_lower = HashEmbed(width, nr_vector, column=1)
         embed_prefix = HashEmbed(width//2, nr_vector, column=2)
@@ -378,25 +382,26 @@ def build_text_classifier(nr_class, width=64, **cfg):
             FeatureExtracter([ORTH, LOWER, PREFIX, SUFFIX, SHAPE])
             >> _flatten_add_lengths
             >> with_getitem(0,
+                uniqued(
                 (embed_lower | embed_prefix | embed_suffix | embed_shape)
-                >> Maxout(width, width+(width//2)*3)
+                >> Maxout(width, width+(width//2)*3))
                 >> Residual(ExtractWindow(nW=1) >> ReLu(width, width*3))
                 >> Residual(ExtractWindow(nW=1) >> ReLu(width, width*3))
                 >> Residual(ExtractWindow(nW=1) >> ReLu(width, width*3))
             )
-            >> Pooling(mean_pool, max_pool)
-            >> Residual(ReLu(width*2, width*2))
+            >> ParametricAttention(width,)
+            >> Pooling(sum_pool)
+            >> ReLu(width, width)
+            >> zero_init(Affine(nr_class, width, drop_factor=0.0))
         )
         linear_model = (
             _preprocess_doc
-            >> LinearModel(nr_class)
-            >> logistic
+            >> LinearModel(nr_class, drop_factor=0.)
         )

         model = (
-            #(linear_model | cnn_model)
-            cnn_model
-            >> zero_init(Affine(nr_class, width*2+nr_class, drop_factor=0.0))
+            (linear_model | cnn_model)
+            >> zero_init(Affine(nr_class, nr_class*2, drop_factor=0.0))
             >> logistic
         )

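The model definitions above lean on thinc's operator overloading: inside the Model.define_operators block, '>>' pipes one layer's output into the next, and '|' runs layers in parallel on the same input and concatenates their outputs. As a rough illustration only (toy classes standing in for thinc's actual Model API, not the real implementation), the pattern works like this:

# Toy stand-in for the combinator pattern used in build_text_classifier.
# These classes are illustrative assumptions, not thinc's Model API.

class Layer(object):
    """Toy layer wrapping a function of one input."""
    def __init__(self, fn):
        self.fn = fn

    def __call__(self, X):
        return self.fn(X)

    def __rshift__(self, other):
        # '>>': chain -- feed this layer's output into the next layer
        return Layer(lambda X: other(self(X)))

    def __or__(self, other):
        # '|': run both layers on the same input, concatenate the results
        return Layer(lambda X: self(X) + other(X))


double = Layer(lambda X: [2 * x for x in X])
inc = Layer(lambda X: [x + 1 for x in X])

pipeline = (double | inc) >> double   # reads like the diff's model expressions
print(pipeline([1, 2]))               # [4, 8, 4, 6]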
spacy/compat.py

@@ -5,6 +5,7 @@ import six
 import ftfy
 import sys
 import ujson
+import itertools

 from thinc.neural.util import copy_array

@@ -35,6 +36,7 @@ CudaStream = CudaStream
 cupy = cupy
 fix_text = ftfy.fix_text
 copy_array = copy_array
+izip = getattr(itertools, 'izip', zip)

 is_python2 = six.PY2
 is_python3 = six.PY3
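The new izip alias is a Python 2/3 shim: itertools.izip was Python 2's lazy variant of zip and was dropped in Python 3, where the builtin zip is already lazy. The getattr fallback picks whichever exists, so callers can stream pairs without materialising either iterable. A quick sanity check that runs on both interpreters:

import itertools

izip = getattr(itertools, 'izip', zip)

# Pairs are produced lazily; neither generator is consumed up front.
pairs = izip((c for c in 'ab'), (n for n in [1, 2]))
assert list(pairs) == [('a', 1), ('b', 2)]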
spacy/language.py

@@ -10,6 +10,7 @@ from thinc.neural.optimizers import Adam, SGD
 import random
 import ujson
 from collections import OrderedDict
+import itertools

 from .tokenizer import Tokenizer
 from .vocab import Vocab
@@ -25,7 +26,7 @@ from .pipeline import SimilarityHook
 from .pipeline import TextCategorizer
 from . import about

-from .compat import json_dumps
+from .compat import json_dumps, izip
 from .attrs import IS_STOP
 from .lang.punctuation import TOKENIZER_PREFIXES, TOKENIZER_SUFFIXES, TOKENIZER_INFIXES
 from .lang.tokenizer_exceptions import TOKEN_MATCH
@@ -411,7 +412,7 @@ class Language(object):
         except StopIteration:
             pass

-    def pipe(self, texts, n_threads=2, batch_size=1000, disable=[]):
+    def pipe(self, texts, tuples=False, n_threads=2, batch_size=1000, disable=[]):
         """Process texts as a stream, and yield `Doc` objects in order. Supports
         GIL-free multi-threading.

@@ -427,8 +428,16 @@ class Language(object):
         >>> for doc in nlp.pipe(texts, batch_size=50, n_threads=4):
         >>>     assert doc.is_parsed
         """
+        if tuples:
+            text_context1, text_context2 = itertools.tee(texts)
+            texts = (tc[0] for tc in text_context1)
+            contexts = (tc[1] for tc in text_context2)
+            docs = self.pipe(texts, n_threads=n_threads, batch_size=batch_size,
+                             disable=disable)
+            for doc, context in izip(docs, contexts):
+                yield (doc, context)
+            return
         docs = (self.make_doc(text) for text in texts)
-        docs = texts
         for proc in self.pipeline:
             name = getattr(proc, 'name', None)
             if name in disable:
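With the new tuples flag, Language.pipe can carry arbitrary context alongside each text: pass an iterable of (text, context) pairs and it yields (doc, context) pairs, tee-ing the stream so the contexts never pass through the pipeline itself. A usage sketch, assuming an nlp object loaded elsewhere (e.g. via spacy.load) and hypothetical example data:

# Hypothetical (text, context) pairs; the contexts ride along unchanged.
train_data = [('This is a text.', {'cats': {'POSITIVE': 1}}),
              ('Another example.', {'cats': {'POSITIVE': 0}})]

for doc, gold in nlp.pipe(train_data, tuples=True, batch_size=50):
    print(doc.text, gold['cats'])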