mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-14 03:26:24 +03:00
speeding up training
This commit is contained in:
parent
66813a1fdc
commit
6521cfa132
|
@ -115,6 +115,7 @@ def run_pipeline():
|
||||||
|
|
||||||
# STEP 6: create the entity linking pipe
|
# STEP 6: create the entity linking pipe
|
||||||
if train_pipe:
|
if train_pipe:
|
||||||
|
print("STEP 6: training Entity Linking pipe", datetime.datetime.now())
|
||||||
train_limit = 100
|
train_limit = 100
|
||||||
dev_limit = 20
|
dev_limit = 20
|
||||||
print("Training on", train_limit, "articles")
|
print("Training on", train_limit, "articles")
|
||||||
|
@ -131,7 +132,7 @@ def run_pipeline():
|
||||||
training_dir=TRAINING_DIR,
|
training_dir=TRAINING_DIR,
|
||||||
dev=True,
|
dev=True,
|
||||||
limit=dev_limit,
|
limit=dev_limit,
|
||||||
to_print=False)
|
to_print=False)
|
||||||
|
|
||||||
el_pipe = nlp.create_pipe(name='entity_linker', config={"kb": my_kb, "doc_cutoff": DOC_CHAR_CUTOFF})
|
el_pipe = nlp.create_pipe(name='entity_linker', config={"kb": my_kb, "doc_cutoff": DOC_CHAR_CUTOFF})
|
||||||
nlp.add_pipe(el_pipe, last=True)
|
nlp.add_pipe(el_pipe, last=True)
|
||||||
|
@ -147,35 +148,40 @@ def run_pipeline():
|
||||||
|
|
||||||
with nlp.disable_pipes(*other_pipes):
|
with nlp.disable_pipes(*other_pipes):
|
||||||
for batch in batches:
|
for batch in batches:
|
||||||
docs, golds = zip(*batch)
|
try:
|
||||||
nlp.update(
|
docs, golds = zip(*batch)
|
||||||
docs,
|
nlp.update(
|
||||||
golds,
|
docs,
|
||||||
drop=DROPOUT,
|
golds,
|
||||||
losses=losses,
|
drop=DROPOUT,
|
||||||
)
|
losses=losses,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
print("Error updating batch", e)
|
||||||
|
|
||||||
# print(" measuring accuracy 1-1")
|
print("Epoch, train loss", itn, round(losses['entity_linker'], 2))
|
||||||
el_pipe.context_weight = 1
|
|
||||||
el_pipe.prior_weight = 1
|
|
||||||
dev_acc_1_1 = _measure_accuracy(dev_data, el_pipe)
|
|
||||||
train_acc_1_1 = _measure_accuracy(train_data, el_pipe)
|
|
||||||
|
|
||||||
# print(" measuring accuracy 0-1")
|
# baseline using only prior probabilities
|
||||||
el_pipe.context_weight = 0
|
el_pipe.context_weight = 0
|
||||||
el_pipe.prior_weight = 1
|
el_pipe.prior_weight = 1
|
||||||
dev_acc_0_1 = _measure_accuracy(dev_data, el_pipe)
|
dev_acc_0_1 = _measure_accuracy(dev_data, el_pipe)
|
||||||
train_acc_0_1 = _measure_accuracy(train_data, el_pipe)
|
train_acc_0_1 = _measure_accuracy(train_data, el_pipe)
|
||||||
|
|
||||||
# print(" measuring accuracy 1-0")
|
# print(" measuring accuracy 1-1")
|
||||||
el_pipe.context_weight = 1
|
el_pipe.context_weight = 1
|
||||||
el_pipe.prior_weight = 0
|
el_pipe.prior_weight = 1
|
||||||
dev_acc_1_0 = _measure_accuracy(dev_data, el_pipe)
|
dev_acc_1_1 = _measure_accuracy(dev_data, el_pipe)
|
||||||
train_acc_1_0 = _measure_accuracy(train_data, el_pipe)
|
train_acc_1_1 = _measure_accuracy(train_data, el_pipe)
|
||||||
|
|
||||||
print("Epoch, train loss, train/dev acc, 1-1, 0-1, 1-0:", itn, round(losses['entity_linker'], 2),
|
# print(" measuring accuracy 1-0")
|
||||||
round(train_acc_1_1, 2), round(train_acc_0_1, 2), round(train_acc_1_0, 2), "/",
|
el_pipe.context_weight = 1
|
||||||
round(dev_acc_1_1, 2), round(dev_acc_0_1, 2), round(dev_acc_1_0, 2))
|
el_pipe.prior_weight = 0
|
||||||
|
dev_acc_1_0 = _measure_accuracy(dev_data, el_pipe)
|
||||||
|
train_acc_1_0 = _measure_accuracy(train_data, el_pipe)
|
||||||
|
|
||||||
|
print("train/dev acc, 1-1, 0-1, 1-0:" ,
|
||||||
|
round(train_acc_1_1, 2), round(train_acc_0_1, 2), round(train_acc_1_0, 2), "/",
|
||||||
|
round(dev_acc_1_1, 2), round(dev_acc_0_1, 2), round(dev_acc_1_0, 2))
|
||||||
|
|
||||||
# test Entity Linker
|
# test Entity Linker
|
||||||
if to_test_pipeline:
|
if to_test_pipeline:
|
||||||
|
@ -193,26 +199,29 @@ def _measure_accuracy(data, el_pipe):
|
||||||
|
|
||||||
docs = [d for d, g in data]
|
docs = [d for d, g in data]
|
||||||
docs = el_pipe.pipe(docs)
|
docs = el_pipe.pipe(docs)
|
||||||
|
|
||||||
golds = [g for d, g in data]
|
golds = [g for d, g in data]
|
||||||
|
|
||||||
for doc, gold in zip(docs, golds):
|
for doc, gold in zip(docs, golds):
|
||||||
correct_entries_per_article = dict()
|
try:
|
||||||
for entity in gold.links:
|
correct_entries_per_article = dict()
|
||||||
start, end, gold_kb = entity
|
for entity in gold.links:
|
||||||
correct_entries_per_article[str(start) + "-" + str(end)] = gold_kb
|
start, end, gold_kb = entity
|
||||||
|
correct_entries_per_article[str(start) + "-" + str(end)] = gold_kb
|
||||||
|
|
||||||
for ent in doc.ents:
|
for ent in doc.ents:
|
||||||
if ent.label_ == "PERSON": # TODO: expand to other types
|
if ent.label_ == "PERSON": # TODO: expand to other types
|
||||||
pred_entity = ent.kb_id_
|
pred_entity = ent.kb_id_
|
||||||
start = ent.start
|
start = ent.start
|
||||||
end = ent.end
|
end = ent.end
|
||||||
gold_entity = correct_entries_per_article.get(str(start) + "-" + str(end), None)
|
gold_entity = correct_entries_per_article.get(str(start) + "-" + str(end), None)
|
||||||
if gold_entity is not None:
|
if gold_entity is not None:
|
||||||
if gold_entity == pred_entity:
|
if gold_entity == pred_entity:
|
||||||
correct += 1
|
correct += 1
|
||||||
else:
|
else:
|
||||||
incorrect += 1
|
incorrect += 1
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print("Error assessing accuracy", e)
|
||||||
|
|
||||||
if correct == incorrect == 0:
|
if correct == incorrect == 0:
|
||||||
return 0
|
return 0
|
||||||
|
|
|
@ -1220,8 +1220,13 @@ class EntityLinker(Pipe):
|
||||||
|
|
||||||
def predict(self, docs):
|
def predict(self, docs):
|
||||||
self.require_model()
|
self.require_model()
|
||||||
|
|
||||||
|
if isinstance(docs, Doc):
|
||||||
|
docs = [docs]
|
||||||
|
|
||||||
final_entities = list()
|
final_entities = list()
|
||||||
final_kb_ids = list()
|
final_kb_ids = list()
|
||||||
|
|
||||||
for i, article_doc in enumerate(docs):
|
for i, article_doc in enumerate(docs):
|
||||||
doc_encoding = self.article_encoder([article_doc])
|
doc_encoding = self.article_encoder([article_doc])
|
||||||
for ent in article_doc.ents:
|
for ent in article_doc.ents:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user