reprocessing all of wikipedia for training data

svlandeg 2019-06-16 21:14:45 +02:00
parent 81731907ba
commit 24db1392b9
6 changed files with 98 additions and 100 deletions

View File

@@ -56,7 +56,7 @@ def create_kb(nlp, max_entities_per_alias, min_entity_freq, min_occ,
            frequency_list.append(freq)
            filtered_title_to_id[title] = entity

-    print("Kept", len(filtered_title_to_id.keys()), "out of", len(title_to_id.keys()), "titles")
+    print("Kept", len(filtered_title_to_id.keys()), "out of", len(title_to_id.keys()), "titles with filter frequency", min_entity_freq)

    print()
    print(" * train entity encoder", datetime.datetime.now())

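As a rough, hypothetical illustration of the filtering step this log line reports on: keep only entities whose mention frequency reaches min_entity_freq before adding them to the KB. The helper name and the input dicts below are assumptions for illustration, not code from this commit.

def filter_entities_by_frequency(title_to_id, title_to_freq, min_entity_freq=20):
    # Keep only entities that occur at least `min_entity_freq` times (assumed threshold).
    filtered_title_to_id = dict()
    frequency_list = []
    for title, entity_id in title_to_id.items():
        freq = title_to_freq.get(title, 0)
        if freq >= min_entity_freq:
            frequency_list.append(freq)
            filtered_title_to_id[title] = entity_id
    print("Kept", len(filtered_title_to_id), "out of", len(title_to_id),
          "titles with filter frequency", min_entity_freq)
    return filtered_title_to_id, frequency_list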
View File

@@ -25,9 +25,7 @@ def run_kb_toy_example(kb):
def run_el_dev(nlp, kb, training_dir, limit=None):
-    correct_entries_per_article, _ = training_set_creator.read_training_entities(training_output=training_dir,
-                                                                                  collect_correct=True,
-                                                                                  collect_incorrect=False)
+    correct_entries_per_article, _ = training_set_creator.read_training_entities(training_output=training_dir)

    predictions = list()
    golds = list()

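The dev run above accumulates parallel predictions and golds lists. A minimal, hypothetical sketch of how such lists can be scored once filled (not part of this commit):

def accuracy(predictions, golds):
    # predictions[i] and golds[i] refer to the same mention.
    if not golds:
        return 0.0
    correct = sum(1 for pred, gold in zip(predictions, golds) if pred == gold)
    return correct / len(golds)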
View File

@@ -389,9 +389,7 @@ class EL_Model:
            bp_sent(sent_gradients, sgd=self.sgd_sent)

    def _get_training_data(self, training_dir, id_to_descr, dev, limit, to_print):
-        correct_entries, incorrect_entries = training_set_creator.read_training_entities(training_output=training_dir,
-                                                                                          collect_correct=True,
-                                                                                          collect_incorrect=True)
+        correct_entries, incorrect_entries = training_set_creator.read_training_entities(training_output=training_dir)

        entities_by_cluster = dict()
        gold_by_entity = dict()

View File

@@ -16,12 +16,13 @@ from . import wikipedia_processor as wp, kb_creator
Process Wikipedia interlinks to generate a training dataset for the EL algorithm
"""

-ENTITY_FILE = "gold_entities.csv"
+# ENTITY_FILE = "gold_entities.csv"
+ENTITY_FILE = "gold_entities_100000.csv"  # use this file for faster processing


def create_training(entity_def_input, training_output):
    wp_to_id = kb_creator._get_entity_to_id(entity_def_input)
-    _process_wikipedia_texts(wp_to_id, training_output, limit=100000000)
+    _process_wikipedia_texts(wp_to_id, training_output, limit=None)

def _process_wikipedia_texts(wp_to_id, training_output, limit=None):
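With the full Wikipedia dump being reprocessed (limit=None), the module temporarily points ENTITY_FILE at a smaller gold_entities_100000.csv for faster iteration. The commit does not show how that subset file is produced; the sketch below is one plausible way to cut it from the full CSV. The helper name and the assumption of a single header line are mine.

def write_entity_subset(full_path="gold_entities.csv",
                        subset_path="gold_entities_100000.csv",
                        n_lines=100000):
    # Copy the header plus the first `n_lines` entries to a smaller file.
    with open(full_path, mode="r", encoding="utf8") as infile, \
         open(subset_path, mode="w", encoding="utf8") as outfile:
        for i, line in enumerate(infile):
            if i > n_lines:
                break
            outfile.write(line)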
@@ -290,75 +291,72 @@ def _write_training_entity(outputfile, article_id, alias, entity, start, end):
    outputfile.write(article_id + "|" + alias + "|" + entity + "|" + str(start) + "|" + str(end) + "\n")


-def read_training_entities(training_output):
+def is_dev(article_id):
+    return article_id.endswith("3")
+
+
+def read_training_entities(training_output, dev, limit):
    entityfile_loc = training_output + "/" + ENTITY_FILE
    entries_per_article = dict()
+    article_ids = set()

    with open(entityfile_loc, mode='r', encoding='utf8') as file:
        for line in file:
-            fields = line.replace('\n', "").split(sep='|')
-            article_id = fields[0]
-            alias = fields[1]
-            wp_title = fields[2]
-            start = fields[3]
-            end = fields[4]
-
-            entries_by_offset = entries_per_article.get(article_id, dict())
-            entries_by_offset[start + "-" + end] = (alias, wp_title)
-
-            entries_per_article[article_id] = entries_by_offset
+            if not limit or len(article_ids) < limit:
+                fields = line.replace('\n', "").split(sep='|')
+                article_id = fields[0]
+                if dev == is_dev(article_id) and article_id != "article_id":
+                    article_ids.add(article_id)
+                    alias = fields[1]
+                    wp_title = fields[2]
+                    start = fields[3]
+                    end = fields[4]
+
+                    entries_by_offset = entries_per_article.get(article_id, dict())
+                    entries_by_offset[start + "-" + end] = (alias, wp_title)
+
+                    entries_per_article[article_id] = entries_by_offset

    return entries_per_article


-def read_training(nlp, training_dir, dev, limit, to_print):
-    # This method will provide training examples that correspond to the entity annotations found by the nlp object
-    entries_per_article = read_training_entities(training_output=training_dir)
+def read_training(nlp, training_dir, dev, limit):
+    # This method provides training examples that correspond to the entity annotations found by the nlp object
+    print("reading training entities")
+    entries_per_article = read_training_entities(training_output=training_dir, dev=dev, limit=limit)
+    print("done reading training entities")

    data = []
-    cnt = 0
-    files = listdir(training_dir)
-    for f in files:
-        if not limit or cnt < limit:
-            if dev == run_el.is_dev(f):
-                article_id = f.replace(".txt", "")
-                if cnt % 500 == 0 and to_print:
-                    print(datetime.datetime.now(), "processed", cnt, "files in the training dataset")
-
-                try:
-                    # parse the article text
-                    with open(os.path.join(training_dir, f), mode="r", encoding='utf8') as file:
-                        text = file.read()
-                        article_doc = nlp(text)
-
-                        entries_by_offset = entries_per_article.get(article_id, dict())
-
-                        gold_entities = list()
-                        for ent in article_doc.ents:
-                            start = ent.start_char
-                            end = ent.end_char
-
-                            entity_tuple = entries_by_offset.get(str(start) + "-" + str(end), None)
-                            if entity_tuple:
-                                alias, wp_title = entity_tuple
-                                if ent.text != alias:
-                                    print("Non-matching entity in", article_id, start, end)
-                                else:
-                                    gold_entities.append((start, end, wp_title))
-
-                        if gold_entities:
-                            gold = GoldParse(doc=article_doc, links=gold_entities)
-                            data.append((article_doc, gold))
-
-                    cnt += 1
-                except Exception as e:
-                    print("Problem parsing article", article_id)
-                    print(e)
-                    raise e
-
-    if to_print:
-        print()
-        print("Processed", cnt, "training articles, dev=" + str(dev))
-        print()
+    for article_id, entries_by_offset in entries_per_article.items():
+        file_name = article_id + ".txt"
+        try:
+            # parse the article text
+            with open(os.path.join(training_dir, file_name), mode="r", encoding='utf8') as file:
+                text = file.read()
+                article_doc = nlp(text)
+
+                gold_entities = list()
+                for ent in article_doc.ents:
+                    start = ent.start_char
+                    end = ent.end_char
+
+                    entity_tuple = entries_by_offset.get(str(start) + "-" + str(end), None)
+                    if entity_tuple:
+                        alias, wp_title = entity_tuple
+                        if ent.text != alias:
+                            print("Non-matching entity in", article_id, start, end)
+                        else:
+                            gold_entities.append((start, end, wp_title))
+
+                if gold_entities:
+                    gold = GoldParse(doc=article_doc, links=gold_entities)
+                    data.append((article_doc, gold))
+        except Exception as e:
+            print("Problem parsing article", article_id)
+            print(e)
+            raise e

    return data

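Two details of the new reader are worth spelling out: articles are assigned to the dev set purely by their id (ids ending in "3", so roughly one in ten), and each line of the gold file follows the pipe-delimited layout written by _write_training_entity above. A self-contained sketch with a made-up example line (the field values are invented for illustration):

def is_dev(article_id):
    # Deterministic ~10% dev split: ids ending in "3" go to dev.
    return article_id.endswith("3")

def parse_gold_line(line):
    # Layout written by _write_training_entity: article_id|alias|entity|start|end
    article_id, alias, entity, start, end = line.rstrip("\n").split("|")
    return article_id, alias, entity, int(start), int(end)

example = "12345|Douglas Adams|Q42|10|23\n"  # made-up line
article_id, alias, entity, start, end = parse_gold_line(example)
print(article_id, "->", "dev" if is_dev(article_id) else "train", alias, entity, start, end)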
View File

@@ -29,7 +29,7 @@ NLP_2_DIR = 'C:/Users/Sofie/Documents/data/wikipedia/nlp_2'
TRAINING_DIR = 'C:/Users/Sofie/Documents/data/wikipedia/training_data_nel/'

MAX_CANDIDATES = 10
-MIN_ENTITY_FREQ = 200
+MIN_ENTITY_FREQ = 20
MIN_PAIR_OCC = 5
DOC_SENT_CUTOFF = 2
EPOCHS = 10
@@ -47,21 +47,21 @@ def run_pipeline():
    # one-time methods to create KB and write to file
    to_create_prior_probs = False
    to_create_entity_counts = False
-    to_create_kb = True
+    to_create_kb = False

    # read KB back in from file
-    to_read_kb = False
+    to_read_kb = True
    to_test_kb = False

    # create training dataset
    create_wp_training = False

    # train the EL pipe
-    train_pipe = False
-    measure_performance = False
+    train_pipe = True
+    measure_performance = True

    # test the EL pipe on a simple example
-    to_test_pipeline = False
+    to_test_pipeline = True

    # write the NLP object, read back in and test again
    test_nlp_io = False
@@ -135,46 +135,50 @@ def run_pipeline():
        print("STEP 6: training Entity Linking pipe", datetime.datetime.now())
        train_limit = 10
        dev_limit = 2
-        print("Training on", train_limit, "articles")
-        print("Dev testing on", dev_limit, "articles")
-        print()

        train_data = training_set_creator.read_training(nlp=nlp_2,
                                                         training_dir=TRAINING_DIR,
                                                         dev=False,
-                                                         limit=train_limit,
-                                                         to_print=False)
+                                                         limit=train_limit)
+
+        print("Training on", len(train_data), "articles")
+        print()
+
+        if not train_data:
+            print("Did not find any training data")
+        else:
+            for itn in range(EPOCHS):
+                random.shuffle(train_data)
+                losses = {}
+                batches = minibatch(train_data, size=compounding(4.0, 128.0, 1.001))
+                batchnr = 0
+
+                with nlp_2.disable_pipes(*other_pipes):
+                    for batch in batches:
+                        try:
+                            docs, golds = zip(*batch)
+                            nlp_2.update(
+                                docs,
+                                golds,
+                                drop=DROPOUT,
+                                losses=losses,
+                            )
+                            batchnr += 1
+                        except Exception as e:
+                            print("Error updating batch", e)
+
+                losses['entity_linker'] = losses['entity_linker'] / batchnr
+                print("Epoch, train loss", itn, round(losses['entity_linker'], 2))

        dev_data = training_set_creator.read_training(nlp=nlp_2,
                                                       training_dir=TRAINING_DIR,
                                                       dev=True,
-                                                       limit=dev_limit,
-                                                       to_print=False)
+                                                       limit=dev_limit)
+
+        print("Dev testing on", len(dev_data), "articles")
+        print()

-        for itn in range(EPOCHS):
-            random.shuffle(train_data)
-            losses = {}
-            batches = minibatch(train_data, size=compounding(4.0, 128.0, 1.001))
-            batchnr = 0
-
-            with nlp_2.disable_pipes(*other_pipes):
-                for batch in batches:
-                    try:
-                        docs, golds = zip(*batch)
-                        nlp_2.update(
-                            docs,
-                            golds,
-                            drop=DROPOUT,
-                            losses=losses,
-                        )
-                        batchnr += 1
-                    except Exception as e:
-                        print("Error updating batch", e)
-
-            losses['entity_linker'] = losses['entity_linker'] / batchnr
-            print("Epoch, train loss", itn, round(losses['entity_linker'], 2))
-
-    if measure_performance:
+    if len(dev_data) and measure_performance:
        print()
        print("STEP 7: performance measurement of Entity Linking pipe", datetime.datetime.now())
        print()

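For reference, the epoch loop used here boils down to the spaCy 2.x pattern below: shuffle, batch with a compounding batch size, and call nlp.update on each batch while the other pipes are disabled. This is a stripped-down sketch of the loop shown in the diff, not the script itself; the nlp object, the dropout value and the epoch count are placeholders.

import random

from spacy.util import minibatch, compounding

def train_entity_linker(nlp, train_data, epochs=10, dropout=0.1):
    # train_data is a list of (Doc, GoldParse) pairs, as built by read_training.
    other_pipes = [pipe for pipe in nlp.pipe_names if pipe != "entity_linker"]
    for itn in range(epochs):
        random.shuffle(train_data)
        losses = {}
        # Batch size grows from 4 to 128 over the course of training.
        batches = minibatch(train_data, size=compounding(4.0, 128.0, 1.001))
        batchnr = 0
        with nlp.disable_pipes(*other_pipes):
            for batch in batches:
                docs, golds = zip(*batch)
                nlp.update(docs, golds, drop=dropout, losses=losses)
                batchnr += 1
        if batchnr:
            print("Epoch", itn, "train loss",
                  round(losses["entity_linker"] / batchnr, 2))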
View File

@@ -104,7 +104,7 @@ def read_wikidata_entities_json(limit=None, to_print=False):
            if lang_aliases:
                for item in lang_aliases:
                    if to_print:
-                        print("alias (" + lang + "):", item["value"])
+                        print("alias (" + lang + "):", item["value"])

        if to_print:
            print()
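For context, this function walks the per-language alias lists of each parsed Wikidata record. A minimal illustration of that structure with a made-up record (the surrounding dump parsing is outside this hunk and not shown here):

# Hypothetical parsed Wikidata record; the values are made up.
record = {
    "id": "Q42",
    "aliases": {
        "en": [{"language": "en", "value": "Douglas Noel Adams"}],
    },
}

lang = "en"
lang_aliases = record.get("aliases", {}).get(lang, [])
if lang_aliases:
    for item in lang_aliases:
        print("alias (" + lang + "):", item["value"])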