mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-24 00:46:28 +03:00
Fix print statements in text classifier example
This commit is contained in:
parent
9f9439667b
commit
dad8f09fba
|
@ -26,8 +26,9 @@ from spacy.pipeline import TextCategorizer
|
|||
@plac.annotations(
|
||||
model=("Model name. Defaults to blank 'en' model.", "option", "m", str),
|
||||
output_dir=("Optional output directory", "option", "o", Path),
|
||||
n_examples=("Number of texts to train from", "option", "N", int),
|
||||
n_iter=("Number of training iterations", "option", "n", int))
|
||||
def main(model=None, output_dir=None, n_iter=20):
|
||||
def main(model=None, output_dir=None, n_iter=20, n_texts=2000):
|
||||
if model is not None:
|
||||
nlp = spacy.load(model) # load existing spaCy model
|
||||
print("Loaded model '%s'" % model)
|
||||
|
@ -50,7 +51,8 @@ def main(model=None, output_dir=None, n_iter=20):
|
|||
|
||||
# load the IMBD dataset
|
||||
print("Loading IMDB data...")
|
||||
(train_texts, train_cats), (dev_texts, dev_cats) = load_data(limit=2000)
|
||||
print("Using %d training examples" % n_texts)
|
||||
(train_texts, train_cats), (dev_texts, dev_cats) = load_data(limit=n_texts)
|
||||
train_docs = [nlp.tokenizer(text) for text in train_texts]
|
||||
train_gold = [GoldParse(doc, cats=cats) for doc, cats in
|
||||
zip(train_docs, train_cats)]
|
||||
|
@ -65,14 +67,14 @@ def main(model=None, output_dir=None, n_iter=20):
|
|||
for i in range(n_iter):
|
||||
losses = {}
|
||||
# batch up the examples using spaCy's minibatch
|
||||
batches = minibatch(train_data, size=compounding(4., 128., 1.001))
|
||||
batches = minibatch(train_data, size=compounding(4., 32., 1.001))
|
||||
for batch in batches:
|
||||
docs, golds = zip(*batch)
|
||||
nlp.update(docs, golds, sgd=optimizer, drop=0.2, losses=losses)
|
||||
with textcat.model.use_params(optimizer.averages):
|
||||
# evaluate on the dev data split off in load_data()
|
||||
scores = evaluate(nlp.tokenizer, textcat, dev_texts, dev_cats)
|
||||
print('{0:.3f}\t{0:.3f}\t{0:.3f}\t{0:.3f}' # print a simple table
|
||||
print('{0:.3f}\t{1:.3f}\t{2:.3f}\t{3:.3f}' # print a simple table
|
||||
.format(losses['textcat'], scores['textcat_p'],
|
||||
scores['textcat_r'], scores['textcat_f']))
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user