mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 17:36:30 +03:00
Improve pretrain textcat example
This commit is contained in:
parent
3e7a96f99d
commit
bc8cda818c
|
@ -47,8 +47,8 @@ def load_textcat_data(limit=0):
|
|||
train_data = train_data[-limit:]
|
||||
texts, labels = zip(*train_data)
|
||||
eval_texts, eval_labels = zip(*eval_data)
|
||||
cats = [{'POSITIVE': bool(y)} for y in labels]
|
||||
eval_cats = [{'POSITIVE': bool(y)} for y in eval_labels]
|
||||
cats = [{'POSITIVE': bool(y), 'NEGATIVE': not bool(y)} for y in labels]
|
||||
eval_cats = [{'POSITIVE': bool(y), 'NEGATIVE': not bool(y)} for y in eval_labels]
|
||||
return (texts, cats), (eval_texts, eval_cats)
|
||||
|
||||
|
||||
|
@ -61,9 +61,9 @@ def prefer_gpu():
|
|||
|
||||
|
||||
def build_textcat_model(tok2vec, nr_class, width):
|
||||
from thinc.v2v import Model, Affine, Maxout
|
||||
from thinc.v2v import Model, Softmax, Maxout
|
||||
from thinc.api import flatten_add_lengths, chain
|
||||
from thinc.t2v import Pooling, sum_pool, max_pool
|
||||
from thinc.t2v import Pooling, sum_pool, mean_pool, max_pool
|
||||
from thinc.misc import Residual, LayerNorm
|
||||
from spacy._ml import logistic, zero_init
|
||||
|
||||
|
@ -71,11 +71,8 @@ def build_textcat_model(tok2vec, nr_class, width):
|
|||
model = (
|
||||
tok2vec
|
||||
>> flatten_add_lengths
|
||||
>> Pooling(sum_pool, max_pool)
|
||||
>> Residual(LayerNorm(Maxout(width*2, width*2, pieces=3)))
|
||||
>> Residual(LayerNorm(Maxout(width*2, width*2, pieces=3)))
|
||||
>> zero_init(Affine(nr_class, width*2, drop_factor=0.0))
|
||||
>> logistic
|
||||
>> Pooling(mean_pool)
|
||||
>> Softmax(nr_class, width)
|
||||
)
|
||||
model.tok2vec = tok2vec
|
||||
return model
|
||||
|
@ -92,9 +89,9 @@ def create_pipeline(width, embed_size, vectors_model):
|
|||
nlp = spacy.load(vectors_model)
|
||||
print("Start training")
|
||||
textcat = TextCategorizer(nlp.vocab,
|
||||
labels=['POSITIVE'],
|
||||
labels=['POSITIVE', 'NEGATIVE'],
|
||||
model=build_textcat_model(
|
||||
Tok2Vec(width=width, embed_size=embed_size), 1, width))
|
||||
Tok2Vec(width=width, embed_size=embed_size), 2, width))
|
||||
|
||||
nlp.add_pipe(textcat)
|
||||
return nlp
|
||||
|
|
Loading…
Reference in New Issue
Block a user