mirror of
https://github.com/explosion/spaCy.git
synced 2025-02-03 05:04:09 +03:00
fixing the context/prior weight settings
This commit is contained in:
parent
0ea52c86b8
commit
b7a0c9bf60
|
@ -154,8 +154,8 @@ def run_pipeline():
|
|||
optimizer.L2 = L2
|
||||
|
||||
# define the size (nr of entities) of training and dev set
|
||||
train_limit = 5
|
||||
dev_limit = 5
|
||||
train_limit = 5000
|
||||
dev_limit = 5000
|
||||
|
||||
train_data = training_set_creator.read_training(nlp=nlp_2,
|
||||
training_dir=TRAINING_DIR,
|
||||
|
@ -198,9 +198,8 @@ def run_pipeline():
|
|||
print("Error updating batch:", e)
|
||||
|
||||
if batchnr > 0:
|
||||
with el_pipe.model.use_params(optimizer.averages):
|
||||
el_pipe.context_weight = 1
|
||||
el_pipe.prior_weight = 1
|
||||
el_pipe.cfg["context_weight"] = 1
|
||||
el_pipe.cfg["prior_weight"] = 1
|
||||
dev_acc_context, dev_acc_context_dict = _measure_accuracy(dev_data, el_pipe)
|
||||
losses['entity_linker'] = losses['entity_linker'] / batchnr
|
||||
print("Epoch, train loss", itn, round(losses['entity_linker'], 2),
|
||||
|
@ -218,24 +217,19 @@ def run_pipeline():
|
|||
print("dev acc random:", round(acc_r, 3), [(x, round(y, 3)) for x, y in acc_r_label.items()])
|
||||
print("dev acc prior:", round(acc_p, 3), [(x, round(y, 3)) for x, y in acc_p_label.items()])
|
||||
|
||||
with el_pipe.model.use_params(optimizer.averages):
|
||||
# measuring combined accuracy (prior + context)
|
||||
el_pipe.context_weight = 1
|
||||
el_pipe.prior_weight = 1
|
||||
dev_acc_combo, dev_acc_combo_dict = _measure_accuracy(dev_data, el_pipe, error_analysis=False)
|
||||
print("dev acc combo avg:", round(dev_acc_combo, 3),
|
||||
[(x, round(y, 3)) for x, y in dev_acc_combo_dict.items()])
|
||||
|
||||
# using only context
|
||||
el_pipe.context_weight = 1
|
||||
el_pipe.prior_weight = 0
|
||||
el_pipe.cfg["context_weight"] = 1
|
||||
el_pipe.cfg["prior_weight"] = 0
|
||||
dev_acc_context, dev_acc_context_dict = _measure_accuracy(dev_data, el_pipe)
|
||||
print("dev acc context avg:", round(dev_acc_context, 3),
|
||||
[(x, round(y, 3)) for x, y in dev_acc_context_dict.items()])
|
||||
|
||||
# reset for follow-up tests
|
||||
el_pipe.context_weight = 1
|
||||
el_pipe.prior_weight = 1
|
||||
# measuring combined accuracy (prior + context)
|
||||
el_pipe.cfg["context_weight"] = 1
|
||||
el_pipe.cfg["prior_weight"] = 1
|
||||
dev_acc_combo, dev_acc_combo_dict = _measure_accuracy(dev_data, el_pipe, error_analysis=False)
|
||||
print("dev acc combo avg:", round(dev_acc_combo, 3),
|
||||
[(x, round(y, 3)) for x, y in dev_acc_combo_dict.items()])
|
||||
|
||||
# STEP 8: apply the EL pipe on a toy example
|
||||
if to_test_pipeline:
|
||||
|
@ -250,7 +244,6 @@ def run_pipeline():
|
|||
print("STEP 9: testing NLP IO", datetime.datetime.now())
|
||||
print()
|
||||
print("writing to", NLP_2_DIR)
|
||||
with el_pipe.model.use_params(optimizer.averages) and el_pipe.model.tok2vec.use_params(el_pipe.sgd_context.averages):
|
||||
nlp_2.to_disk(NLP_2_DIR)
|
||||
print()
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user