mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-13 05:07:03 +03:00
Improve pretrain textcat example
This commit is contained in:
parent
3e7a96f99d
commit
bc8cda818c
|
@ -47,8 +47,8 @@ def load_textcat_data(limit=0):
|
||||||
train_data = train_data[-limit:]
|
train_data = train_data[-limit:]
|
||||||
texts, labels = zip(*train_data)
|
texts, labels = zip(*train_data)
|
||||||
eval_texts, eval_labels = zip(*eval_data)
|
eval_texts, eval_labels = zip(*eval_data)
|
||||||
cats = [{'POSITIVE': bool(y)} for y in labels]
|
cats = [{'POSITIVE': bool(y), 'NEGATIVE': not bool(y)} for y in labels]
|
||||||
eval_cats = [{'POSITIVE': bool(y)} for y in eval_labels]
|
eval_cats = [{'POSITIVE': bool(y), 'NEGATIVE': not bool(y)} for y in eval_labels]
|
||||||
return (texts, cats), (eval_texts, eval_cats)
|
return (texts, cats), (eval_texts, eval_cats)
|
||||||
|
|
||||||
|
|
||||||
|
@ -61,9 +61,9 @@ def prefer_gpu():
|
||||||
|
|
||||||
|
|
||||||
def build_textcat_model(tok2vec, nr_class, width):
|
def build_textcat_model(tok2vec, nr_class, width):
|
||||||
from thinc.v2v import Model, Affine, Maxout
|
from thinc.v2v import Model, Softmax, Maxout
|
||||||
from thinc.api import flatten_add_lengths, chain
|
from thinc.api import flatten_add_lengths, chain
|
||||||
from thinc.t2v import Pooling, sum_pool, max_pool
|
from thinc.t2v import Pooling, sum_pool, mean_pool, max_pool
|
||||||
from thinc.misc import Residual, LayerNorm
|
from thinc.misc import Residual, LayerNorm
|
||||||
from spacy._ml import logistic, zero_init
|
from spacy._ml import logistic, zero_init
|
||||||
|
|
||||||
|
@ -71,11 +71,8 @@ def build_textcat_model(tok2vec, nr_class, width):
|
||||||
model = (
|
model = (
|
||||||
tok2vec
|
tok2vec
|
||||||
>> flatten_add_lengths
|
>> flatten_add_lengths
|
||||||
>> Pooling(sum_pool, max_pool)
|
>> Pooling(mean_pool)
|
||||||
>> Residual(LayerNorm(Maxout(width*2, width*2, pieces=3)))
|
>> Softmax(nr_class, width)
|
||||||
>> Residual(LayerNorm(Maxout(width*2, width*2, pieces=3)))
|
|
||||||
>> zero_init(Affine(nr_class, width*2, drop_factor=0.0))
|
|
||||||
>> logistic
|
|
||||||
)
|
)
|
||||||
model.tok2vec = tok2vec
|
model.tok2vec = tok2vec
|
||||||
return model
|
return model
|
||||||
|
@ -92,9 +89,9 @@ def create_pipeline(width, embed_size, vectors_model):
|
||||||
nlp = spacy.load(vectors_model)
|
nlp = spacy.load(vectors_model)
|
||||||
print("Start training")
|
print("Start training")
|
||||||
textcat = TextCategorizer(nlp.vocab,
|
textcat = TextCategorizer(nlp.vocab,
|
||||||
labels=['POSITIVE'],
|
labels=['POSITIVE', 'NEGATIVE'],
|
||||||
model=build_textcat_model(
|
model=build_textcat_model(
|
||||||
Tok2Vec(width=width, embed_size=embed_size), 1, width))
|
Tok2Vec(width=width, embed_size=embed_size), 2, width))
|
||||||
|
|
||||||
nlp.add_pipe(textcat)
|
nlp.add_pipe(textcat)
|
||||||
return nlp
|
return nlp
|
||||||
|
|
Loading…
Reference in New Issue
Block a user