mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-24 17:06:29 +03:00
Clarify train textcat example
This commit is contained in:
parent
27abc56e98
commit
c16ef0a85c
|
@ -28,8 +28,8 @@ def train_textcat(tokenizer, textcat,
|
||||||
batch_sizes = compounding(4., 128., 1.001)
|
batch_sizes = compounding(4., 128., 1.001)
|
||||||
for i in range(n_iter):
|
for i in range(n_iter):
|
||||||
losses = {}
|
losses = {}
|
||||||
for batch in minibatch(tqdm.tqdm(train_data, leave=False),
|
train_data = tqdm.tqdm(train_data, leave=False) # Progress bar
|
||||||
size=batch_sizes):
|
for batch in minibatch(train_data, size=batch_sizes):
|
||||||
docs, golds = zip(*batch)
|
docs, golds = zip(*batch)
|
||||||
textcat.update((docs, None), golds, sgd=optimizer, drop=0.2,
|
textcat.update((docs, None), golds, sgd=optimizer, drop=0.2,
|
||||||
losses=losses)
|
losses=losses)
|
||||||
|
@ -70,7 +70,7 @@ def load_data():
|
||||||
|
|
||||||
texts, labels = zip(*train_data)
|
texts, labels = zip(*train_data)
|
||||||
cats = [(['POSITIVE'] if y else []) for y in labels]
|
cats = [(['POSITIVE'] if y else []) for y in labels]
|
||||||
|
|
||||||
split = int(len(train_data) * 0.8)
|
split = int(len(train_data) * 0.8)
|
||||||
|
|
||||||
train_texts = texts[:split]
|
train_texts = texts[:split]
|
||||||
|
@ -104,7 +104,6 @@ def main(model_loc=None):
|
||||||
doc = nlp(u'This movie sucked!')
|
doc = nlp(u'This movie sucked!')
|
||||||
print(doc.cats)
|
print(doc.cats)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
plac.call(main)
|
plac.call(main)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user