mirror of
https://github.com/explosion/spaCy.git
synced 2025-02-13 10:00:34 +03:00
fixing the context/prior weight settings
This commit is contained in:
parent
0ea52c86b8
commit
b7a0c9bf60
|
@ -154,8 +154,8 @@ def run_pipeline():
|
||||||
optimizer.L2 = L2
|
optimizer.L2 = L2
|
||||||
|
|
||||||
# define the size (nr of entities) of training and dev set
|
# define the size (nr of entities) of training and dev set
|
||||||
train_limit = 5
|
train_limit = 5000
|
||||||
dev_limit = 5
|
dev_limit = 5000
|
||||||
|
|
||||||
train_data = training_set_creator.read_training(nlp=nlp_2,
|
train_data = training_set_creator.read_training(nlp=nlp_2,
|
||||||
training_dir=TRAINING_DIR,
|
training_dir=TRAINING_DIR,
|
||||||
|
@ -198,9 +198,8 @@ def run_pipeline():
|
||||||
print("Error updating batch:", e)
|
print("Error updating batch:", e)
|
||||||
|
|
||||||
if batchnr > 0:
|
if batchnr > 0:
|
||||||
with el_pipe.model.use_params(optimizer.averages):
|
el_pipe.cfg["context_weight"] = 1
|
||||||
el_pipe.context_weight = 1
|
el_pipe.cfg["prior_weight"] = 1
|
||||||
el_pipe.prior_weight = 1
|
|
||||||
dev_acc_context, dev_acc_context_dict = _measure_accuracy(dev_data, el_pipe)
|
dev_acc_context, dev_acc_context_dict = _measure_accuracy(dev_data, el_pipe)
|
||||||
losses['entity_linker'] = losses['entity_linker'] / batchnr
|
losses['entity_linker'] = losses['entity_linker'] / batchnr
|
||||||
print("Epoch, train loss", itn, round(losses['entity_linker'], 2),
|
print("Epoch, train loss", itn, round(losses['entity_linker'], 2),
|
||||||
|
@ -218,24 +217,19 @@ def run_pipeline():
|
||||||
print("dev acc random:", round(acc_r, 3), [(x, round(y, 3)) for x, y in acc_r_label.items()])
|
print("dev acc random:", round(acc_r, 3), [(x, round(y, 3)) for x, y in acc_r_label.items()])
|
||||||
print("dev acc prior:", round(acc_p, 3), [(x, round(y, 3)) for x, y in acc_p_label.items()])
|
print("dev acc prior:", round(acc_p, 3), [(x, round(y, 3)) for x, y in acc_p_label.items()])
|
||||||
|
|
||||||
with el_pipe.model.use_params(optimizer.averages):
|
|
||||||
# measuring combined accuracy (prior + context)
|
|
||||||
el_pipe.context_weight = 1
|
|
||||||
el_pipe.prior_weight = 1
|
|
||||||
dev_acc_combo, dev_acc_combo_dict = _measure_accuracy(dev_data, el_pipe, error_analysis=False)
|
|
||||||
print("dev acc combo avg:", round(dev_acc_combo, 3),
|
|
||||||
[(x, round(y, 3)) for x, y in dev_acc_combo_dict.items()])
|
|
||||||
|
|
||||||
# using only context
|
# using only context
|
||||||
el_pipe.context_weight = 1
|
el_pipe.cfg["context_weight"] = 1
|
||||||
el_pipe.prior_weight = 0
|
el_pipe.cfg["prior_weight"] = 0
|
||||||
dev_acc_context, dev_acc_context_dict = _measure_accuracy(dev_data, el_pipe)
|
dev_acc_context, dev_acc_context_dict = _measure_accuracy(dev_data, el_pipe)
|
||||||
print("dev acc context avg:", round(dev_acc_context, 3),
|
print("dev acc context avg:", round(dev_acc_context, 3),
|
||||||
[(x, round(y, 3)) for x, y in dev_acc_context_dict.items()])
|
[(x, round(y, 3)) for x, y in dev_acc_context_dict.items()])
|
||||||
|
|
||||||
# reset for follow-up tests
|
# measuring combined accuracy (prior + context)
|
||||||
el_pipe.context_weight = 1
|
el_pipe.cfg["context_weight"] = 1
|
||||||
el_pipe.prior_weight = 1
|
el_pipe.cfg["prior_weight"] = 1
|
||||||
|
dev_acc_combo, dev_acc_combo_dict = _measure_accuracy(dev_data, el_pipe, error_analysis=False)
|
||||||
|
print("dev acc combo avg:", round(dev_acc_combo, 3),
|
||||||
|
[(x, round(y, 3)) for x, y in dev_acc_combo_dict.items()])
|
||||||
|
|
||||||
# STEP 8: apply the EL pipe on a toy example
|
# STEP 8: apply the EL pipe on a toy example
|
||||||
if to_test_pipeline:
|
if to_test_pipeline:
|
||||||
|
@ -250,7 +244,6 @@ def run_pipeline():
|
||||||
print("STEP 9: testing NLP IO", datetime.datetime.now())
|
print("STEP 9: testing NLP IO", datetime.datetime.now())
|
||||||
print()
|
print()
|
||||||
print("writing to", NLP_2_DIR)
|
print("writing to", NLP_2_DIR)
|
||||||
with el_pipe.model.use_params(optimizer.averages) and el_pipe.model.tok2vec.use_params(el_pipe.sgd_context.averages):
|
|
||||||
nlp_2.to_disk(NLP_2_DIR)
|
nlp_2.to_disk(NLP_2_DIR)
|
||||||
print()
|
print()
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user