mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-27 17:54:39 +03:00
* Increase default number of iterations from 5 to 10
This commit is contained in:
parent
3cab1d9a29
commit
949a6245f9
|
@ -10,6 +10,7 @@ import random
|
||||||
import json
|
import json
|
||||||
import cython
|
import cython
|
||||||
|
|
||||||
|
|
||||||
from .context cimport fill_slots
|
from .context cimport fill_slots
|
||||||
from .context cimport fill_flat
|
from .context cimport fill_flat
|
||||||
from .context cimport N_FIELDS
|
from .context cimport N_FIELDS
|
||||||
|
@ -33,7 +34,7 @@ def setup_model_dir(tag_type, tag_names, templates, model_dir):
|
||||||
json.dump(config, file_)
|
json.dump(config, file_)
|
||||||
|
|
||||||
|
|
||||||
def train(train_sents, model_dir, nr_iter=5):
|
def train(train_sents, model_dir, nr_iter=10):
|
||||||
tagger = Tagger(model_dir)
|
tagger = Tagger(model_dir)
|
||||||
for _ in range(nr_iter):
|
for _ in range(nr_iter):
|
||||||
n_corr = 0
|
n_corr = 0
|
||||||
|
@ -43,15 +44,15 @@ def train(train_sents, model_dir, nr_iter=5):
|
||||||
for i, gold in enumerate(golds):
|
for i, gold in enumerate(golds):
|
||||||
guess = tagger.predict(i, tokens)
|
guess = tagger.predict(i, tokens)
|
||||||
tokens.set_tag(i, tagger.tag_type, guess)
|
tokens.set_tag(i, tagger.tag_type, guess)
|
||||||
tagger.tell_answer(gold)
|
|
||||||
if gold != NULL_TAG:
|
if gold != NULL_TAG:
|
||||||
|
tagger.tell_answer(gold)
|
||||||
total += 1
|
total += 1
|
||||||
n_corr += guess == gold
|
n_corr += guess == gold
|
||||||
#print('%s\t%d\t%d' % (tokens[i].string, guess, gold))
|
#print('%s\t%d\t%d' % (tokens[i].string, guess, gold))
|
||||||
print('%.4f' % ((n_corr / total) * 100))
|
print('%.4f' % ((n_corr / total) * 100))
|
||||||
random.shuffle(train_sents)
|
random.shuffle(train_sents)
|
||||||
tagger.model.end_training()
|
tagger.model.end_training()
|
||||||
tagger.model.dump(path.join(model_dir, 'model'), freq_thresh=10)
|
tagger.model.dump(path.join(model_dir, 'model'))
|
||||||
|
|
||||||
|
|
||||||
def evaluate(tagger, sents):
|
def evaluate(tagger, sents):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user