mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-13 10:46:29 +03:00
reprocessing all of wikipedia for training data
This commit is contained in:
parent
81731907ba
commit
24db1392b9
|
@ -56,7 +56,7 @@ def create_kb(nlp, max_entities_per_alias, min_entity_freq, min_occ,
|
|||
frequency_list.append(freq)
|
||||
filtered_title_to_id[title] = entity
|
||||
|
||||
print("Kept", len(filtered_title_to_id.keys()), "out of", len(title_to_id.keys()), "titles")
|
||||
print("Kept", len(filtered_title_to_id.keys()), "out of", len(title_to_id.keys()), "titles with filter frequency", min_entity_freq)
|
||||
|
||||
print()
|
||||
print(" * train entity encoder", datetime.datetime.now())
|
||||
|
|
|
@ -25,9 +25,7 @@ def run_kb_toy_example(kb):
|
|||
|
||||
|
||||
def run_el_dev(nlp, kb, training_dir, limit=None):
|
||||
correct_entries_per_article, _ = training_set_creator.read_training_entities(training_output=training_dir,
|
||||
collect_correct=True,
|
||||
collect_incorrect=False)
|
||||
correct_entries_per_article, _ = training_set_creator.read_training_entities(training_output=training_dir)
|
||||
|
||||
predictions = list()
|
||||
golds = list()
|
||||
|
|
|
@ -389,9 +389,7 @@ class EL_Model:
|
|||
bp_sent(sent_gradients, sgd=self.sgd_sent)
|
||||
|
||||
def _get_training_data(self, training_dir, id_to_descr, dev, limit, to_print):
|
||||
correct_entries, incorrect_entries = training_set_creator.read_training_entities(training_output=training_dir,
|
||||
collect_correct=True,
|
||||
collect_incorrect=True)
|
||||
correct_entries, incorrect_entries = training_set_creator.read_training_entities(training_output=training_dir)
|
||||
|
||||
entities_by_cluster = dict()
|
||||
gold_by_entity = dict()
|
||||
|
|
|
@ -16,12 +16,13 @@ from . import wikipedia_processor as wp, kb_creator
|
|||
Process Wikipedia interlinks to generate a training dataset for the EL algorithm
|
||||
"""
|
||||
|
||||
ENTITY_FILE = "gold_entities.csv"
|
||||
# ENTITY_FILE = "gold_entities.csv"
|
||||
ENTITY_FILE = "gold_entities_100000.csv" # use this file for faster processing
|
||||
|
||||
|
||||
def create_training(entity_def_input, training_output):
|
||||
wp_to_id = kb_creator._get_entity_to_id(entity_def_input)
|
||||
_process_wikipedia_texts(wp_to_id, training_output, limit=100000000)
|
||||
_process_wikipedia_texts(wp_to_id, training_output, limit=None)
|
||||
|
||||
|
||||
def _process_wikipedia_texts(wp_to_id, training_output, limit=None):
|
||||
|
@ -290,14 +291,23 @@ def _write_training_entity(outputfile, article_id, alias, entity, start, end):
|
|||
outputfile.write(article_id + "|" + alias + "|" + entity + "|" + str(start) + "|" + str(end) + "\n")
|
||||
|
||||
|
||||
def read_training_entities(training_output):
|
||||
def is_dev(article_id):
|
||||
return article_id.endswith("3")
|
||||
|
||||
|
||||
def read_training_entities(training_output, dev, limit):
|
||||
entityfile_loc = training_output + "/" + ENTITY_FILE
|
||||
entries_per_article = dict()
|
||||
article_ids = set()
|
||||
|
||||
with open(entityfile_loc, mode='r', encoding='utf8') as file:
|
||||
for line in file:
|
||||
if not limit or len(article_ids) < limit:
|
||||
fields = line.replace('\n', "").split(sep='|')
|
||||
article_id = fields[0]
|
||||
if dev == is_dev(article_id) and article_id != "article_id":
|
||||
article_ids.add(article_id)
|
||||
|
||||
alias = fields[1]
|
||||
wp_title = fields[2]
|
||||
start = fields[3]
|
||||
|
@ -311,29 +321,22 @@ def read_training_entities(training_output):
|
|||
return entries_per_article
|
||||
|
||||
|
||||
def read_training(nlp, training_dir, dev, limit, to_print):
|
||||
# This method will provide training examples that correspond to the entity annotations found by the nlp object
|
||||
entries_per_article = read_training_entities(training_output=training_dir)
|
||||
def read_training(nlp, training_dir, dev, limit):
|
||||
# This method provides training examples that correspond to the entity annotations found by the nlp object
|
||||
|
||||
print("reading training entities")
|
||||
entries_per_article = read_training_entities(training_output=training_dir, dev=dev, limit=limit)
|
||||
print("done reading training entities")
|
||||
|
||||
data = []
|
||||
|
||||
cnt = 0
|
||||
files = listdir(training_dir)
|
||||
for f in files:
|
||||
if not limit or cnt < limit:
|
||||
if dev == run_el.is_dev(f):
|
||||
article_id = f.replace(".txt", "")
|
||||
if cnt % 500 == 0 and to_print:
|
||||
print(datetime.datetime.now(), "processed", cnt, "files in the training dataset")
|
||||
|
||||
for article_id, entries_by_offset in entries_per_article.items():
|
||||
file_name = article_id + ".txt"
|
||||
try:
|
||||
# parse the article text
|
||||
with open(os.path.join(training_dir, f), mode="r", encoding='utf8') as file:
|
||||
with open(os.path.join(training_dir, file_name), mode="r", encoding='utf8') as file:
|
||||
text = file.read()
|
||||
article_doc = nlp(text)
|
||||
|
||||
entries_by_offset = entries_per_article.get(article_id, dict())
|
||||
|
||||
gold_entities = list()
|
||||
for ent in article_doc.ents:
|
||||
start = ent.start_char
|
||||
|
@ -351,14 +354,9 @@ def read_training(nlp, training_dir, dev, limit, to_print):
|
|||
gold = GoldParse(doc=article_doc, links=gold_entities)
|
||||
data.append((article_doc, gold))
|
||||
|
||||
cnt += 1
|
||||
except Exception as e:
|
||||
print("Problem parsing article", article_id)
|
||||
print(e)
|
||||
raise e
|
||||
|
||||
if to_print:
|
||||
print()
|
||||
print("Processed", cnt, "training articles, dev=" + str(dev))
|
||||
print()
|
||||
return data
|
||||
|
|
|
@ -29,7 +29,7 @@ NLP_2_DIR = 'C:/Users/Sofie/Documents/data/wikipedia/nlp_2'
|
|||
TRAINING_DIR = 'C:/Users/Sofie/Documents/data/wikipedia/training_data_nel/'
|
||||
|
||||
MAX_CANDIDATES = 10
|
||||
MIN_ENTITY_FREQ = 200
|
||||
MIN_ENTITY_FREQ = 20
|
||||
MIN_PAIR_OCC = 5
|
||||
DOC_SENT_CUTOFF = 2
|
||||
EPOCHS = 10
|
||||
|
@ -47,21 +47,21 @@ def run_pipeline():
|
|||
# one-time methods to create KB and write to file
|
||||
to_create_prior_probs = False
|
||||
to_create_entity_counts = False
|
||||
to_create_kb = True
|
||||
to_create_kb = False
|
||||
|
||||
# read KB back in from file
|
||||
to_read_kb = False
|
||||
to_read_kb = True
|
||||
to_test_kb = False
|
||||
|
||||
# create training dataset
|
||||
create_wp_training = False
|
||||
|
||||
# train the EL pipe
|
||||
train_pipe = False
|
||||
measure_performance = False
|
||||
train_pipe = True
|
||||
measure_performance = True
|
||||
|
||||
# test the EL pipe on a simple example
|
||||
to_test_pipeline = False
|
||||
to_test_pipeline = True
|
||||
|
||||
# write the NLP object, read back in and test again
|
||||
test_nlp_io = False
|
||||
|
@ -135,22 +135,19 @@ def run_pipeline():
|
|||
print("STEP 6: training Entity Linking pipe", datetime.datetime.now())
|
||||
train_limit = 10
|
||||
dev_limit = 2
|
||||
print("Training on", train_limit, "articles")
|
||||
print("Dev testing on", dev_limit, "articles")
|
||||
print()
|
||||
|
||||
train_data = training_set_creator.read_training(nlp=nlp_2,
|
||||
training_dir=TRAINING_DIR,
|
||||
dev=False,
|
||||
limit=train_limit,
|
||||
to_print=False)
|
||||
limit=train_limit)
|
||||
|
||||
dev_data = training_set_creator.read_training(nlp=nlp_2,
|
||||
training_dir=TRAINING_DIR,
|
||||
dev=True,
|
||||
limit=dev_limit,
|
||||
to_print=False)
|
||||
print("Training on", len(train_data), "articles")
|
||||
print()
|
||||
|
||||
if not train_data:
|
||||
print("Did not find any training data")
|
||||
|
||||
else:
|
||||
for itn in range(EPOCHS):
|
||||
random.shuffle(train_data)
|
||||
losses = {}
|
||||
|
@ -174,7 +171,14 @@ def run_pipeline():
|
|||
losses['entity_linker'] = losses['entity_linker'] / batchnr
|
||||
print("Epoch, train loss", itn, round(losses['entity_linker'], 2))
|
||||
|
||||
if measure_performance:
|
||||
dev_data = training_set_creator.read_training(nlp=nlp_2,
|
||||
training_dir=TRAINING_DIR,
|
||||
dev=True,
|
||||
limit=dev_limit)
|
||||
print("Dev testing on", len(dev_data), "articles")
|
||||
print()
|
||||
|
||||
if len(dev_data) and measure_performance:
|
||||
print()
|
||||
print("STEP 7: performance measurement of Entity Linking pipe", datetime.datetime.now())
|
||||
print()
|
||||
|
|
Loading…
Reference in New Issue
Block a user