mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-10 19:57:17 +03:00
using entity descriptions and article texts as input embedding vectors for training
This commit is contained in:
parent
7e348d7f7f
commit
9f33732b96
|
@ -4,13 +4,16 @@ from __future__ import unicode_literals
|
|||
import spacy
|
||||
from spacy.kb import KnowledgeBase
|
||||
|
||||
import csv
|
||||
import datetime
|
||||
|
||||
from . import wikipedia_processor as wp
|
||||
from . import wikidata_processor as wd
|
||||
|
||||
|
||||
def create_kb(vocab, max_entities_per_alias, min_occ, entity_output, count_input, prior_prob_input,
|
||||
def create_kb(vocab, max_entities_per_alias, min_occ,
|
||||
entity_def_output, entity_descr_output,
|
||||
count_input, prior_prob_input,
|
||||
to_print=False, write_entity_defs=True):
|
||||
""" Create the knowledge base from Wikidata entries """
|
||||
kb = KnowledgeBase(vocab=vocab)
|
||||
|
@ -18,15 +21,11 @@ def create_kb(vocab, max_entities_per_alias, min_occ, entity_output, count_input
|
|||
print()
|
||||
print("1. _read_wikidata_entities", datetime.datetime.now())
|
||||
print()
|
||||
# title_to_id = _read_wikidata_entities_regex_depr(limit=1000)
|
||||
title_to_id = wd.read_wikidata_entities_json(limit=None)
|
||||
title_to_id, id_to_descr = wd.read_wikidata_entities_json(limit=None)
|
||||
|
||||
# write the title-ID mapping to file
|
||||
# write the title-ID and ID-description mappings to file
|
||||
if write_entity_defs:
|
||||
with open(entity_output, mode='w', encoding='utf8') as entity_file:
|
||||
entity_file.write("WP_title" + "|" + "WD_id" + "\n")
|
||||
for title, qid in title_to_id.items():
|
||||
entity_file.write(title + "|" + str(qid) + "\n")
|
||||
_write_entity_files(entity_def_output, entity_descr_output, title_to_id, id_to_descr)
|
||||
|
||||
title_list = list(title_to_id.keys())
|
||||
entity_list = [title_to_id[x] for x in title_list]
|
||||
|
@ -57,6 +56,41 @@ def create_kb(vocab, max_entities_per_alias, min_occ, entity_output, count_input
|
|||
return kb
|
||||
|
||||
|
||||
def _write_entity_files(entity_def_output, entity_descr_output, title_to_id, id_to_descr):
|
||||
with open(entity_def_output, mode='w', encoding='utf8') as id_file:
|
||||
id_file.write("WP_title" + "|" + "WD_id" + "\n")
|
||||
for title, qid in title_to_id.items():
|
||||
id_file.write(title + "|" + str(qid) + "\n")
|
||||
with open(entity_descr_output, mode='w', encoding='utf8') as descr_file:
|
||||
descr_file.write("WD_id" + "|" + "description" + "\n")
|
||||
for qid, descr in id_to_descr.items():
|
||||
descr_file.write(str(qid) + "|" + descr + "\n")
|
||||
|
||||
|
||||
def _get_entity_to_id(entity_def_output):
|
||||
entity_to_id = dict()
|
||||
with open(entity_def_output, 'r', encoding='utf8') as csvfile:
|
||||
csvreader = csv.reader(csvfile, delimiter='|')
|
||||
# skip header
|
||||
next(csvreader)
|
||||
for row in csvreader:
|
||||
entity_to_id[row[0]] = row[1]
|
||||
|
||||
return entity_to_id
|
||||
|
||||
|
||||
def _get_id_to_description(entity_descr_output):
|
||||
id_to_desc = dict()
|
||||
with open(entity_descr_output, 'r', encoding='utf8') as csvfile:
|
||||
csvreader = csv.reader(csvfile, delimiter='|')
|
||||
# skip header
|
||||
next(csvreader)
|
||||
for row in csvreader:
|
||||
id_to_desc[row[0]] = row[1]
|
||||
|
||||
return id_to_desc
|
||||
|
||||
|
||||
def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_input, to_print=False):
|
||||
wp_titles = title_to_id.keys()
|
||||
|
||||
|
|
|
@ -32,7 +32,7 @@ def run_el_toy_example(nlp, kb):
|
|||
print("ent", ent.text, ent.label_, ent.kb_id_)
|
||||
|
||||
|
||||
def run_el_training(nlp, kb, training_dir, limit=None):
|
||||
def run_el_dev(nlp, kb, training_dir, limit=None):
|
||||
_prepare_pipeline(nlp, kb)
|
||||
|
||||
correct_entries_per_article, _ = training_set_creator.read_training_entities(training_output=training_dir,
|
||||
|
@ -48,7 +48,7 @@ def run_el_training(nlp, kb, training_dir, limit=None):
|
|||
if is_dev(f):
|
||||
article_id = f.replace(".txt", "")
|
||||
if cnt % 500 == 0:
|
||||
print(datetime.datetime.now(), "processed", cnt, "files in the training dataset")
|
||||
print(datetime.datetime.now(), "processed", cnt, "files in the dev dataset")
|
||||
cnt += 1
|
||||
with open(os.path.join(training_dir, f), mode="r", encoding='utf8') as file:
|
||||
text = file.read()
|
||||
|
|
58
examples/pipeline/wiki_entity_linking/train_el.py
Normal file
58
examples/pipeline/wiki_entity_linking/train_el.py
Normal file
|
@ -0,0 +1,58 @@
|
|||
# coding: utf-8
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import os
|
||||
import datetime
|
||||
from os import listdir
|
||||
|
||||
from examples.pipeline.wiki_entity_linking import run_el, training_set_creator, kb_creator
|
||||
from examples.pipeline.wiki_entity_linking import wikidata_processor as wd
|
||||
|
||||
""" TODO: this code needs to be implemented in pipes.pyx"""
|
||||
|
||||
|
||||
def train_model(kb, nlp, training_dir, entity_descr_output, limit=None):
|
||||
run_el._prepare_pipeline(nlp, kb)
|
||||
|
||||
correct_entries, incorrect_entries = training_set_creator.read_training_entities(training_output=training_dir,
|
||||
collect_correct=True,
|
||||
collect_incorrect=True)
|
||||
|
||||
entities = kb.get_entity_strings()
|
||||
|
||||
id_to_descr = kb_creator._get_id_to_description(entity_descr_output)
|
||||
|
||||
cnt = 0
|
||||
for f in listdir(training_dir):
|
||||
if not limit or cnt < limit:
|
||||
if not run_el.is_dev(f):
|
||||
article_id = f.replace(".txt", "")
|
||||
if cnt % 500 == 0:
|
||||
print(datetime.datetime.now(), "processed", cnt, "files in the dev dataset")
|
||||
cnt += 1
|
||||
with open(os.path.join(training_dir, f), mode="r", encoding='utf8') as file:
|
||||
text = file.read()
|
||||
print()
|
||||
doc = nlp(text)
|
||||
doc_vector = doc.vector
|
||||
print("FILE", f, len(doc_vector), "D vector")
|
||||
|
||||
for mention_pos, entity_pos in correct_entries[article_id].items():
|
||||
descr = id_to_descr.get(entity_pos)
|
||||
if descr:
|
||||
doc_descr = nlp(descr)
|
||||
descr_vector = doc_descr.vector
|
||||
print("GOLD POS", mention_pos, entity_pos, len(descr_vector), "D vector")
|
||||
|
||||
for mention_neg, entity_negs in incorrect_entries[article_id].items():
|
||||
for entity_neg in entity_negs:
|
||||
descr = id_to_descr.get(entity_neg)
|
||||
if descr:
|
||||
doc_descr = nlp(descr)
|
||||
descr_vector = doc_descr.vector
|
||||
print("GOLD NEG", mention_neg, entity_neg, len(descr_vector), "D vector")
|
||||
|
||||
print()
|
||||
print("Processed", cnt, "dev articles")
|
||||
print()
|
||||
|
|
@ -6,7 +6,7 @@ import csv
|
|||
import bz2
|
||||
import datetime
|
||||
|
||||
from . import wikipedia_processor as wp
|
||||
from . import wikipedia_processor as wp, kb_creator
|
||||
|
||||
"""
|
||||
Process Wikipedia interlinks to generate a training dataset for the EL algorithm
|
||||
|
@ -14,26 +14,15 @@ Process Wikipedia interlinks to generate a training dataset for the EL algorithm
|
|||
|
||||
ENTITY_FILE = "gold_entities.csv"
|
||||
|
||||
def create_training(kb, entity_input, training_output):
|
||||
|
||||
def create_training(kb, entity_def_input, training_output):
|
||||
if not kb:
|
||||
raise ValueError("kb should be defined")
|
||||
# nlp = spacy.load('en_core_web_sm')
|
||||
wp_to_id = _get_entity_to_id(entity_input)
|
||||
wp_to_id = kb_creator._get_entity_to_id(entity_def_input)
|
||||
_process_wikipedia_texts(kb, wp_to_id, training_output, limit=100000000) # TODO: full dataset
|
||||
|
||||
|
||||
def _get_entity_to_id(entity_input):
|
||||
entity_to_id = dict()
|
||||
with open(entity_input, 'r', encoding='utf8') as csvfile:
|
||||
csvreader = csv.reader(csvfile, delimiter='|')
|
||||
# skip header
|
||||
next(csvreader)
|
||||
for row in csvreader:
|
||||
entity_to_id[row[0]] = row[1]
|
||||
|
||||
return entity_to_id
|
||||
|
||||
|
||||
def _process_wikipedia_texts(kb, wp_to_id, training_output, limit=None):
|
||||
"""
|
||||
Read the XML wikipedia data to parse out training data:
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
# coding: utf-8
|
||||
from __future__ import unicode_literals
|
||||
|
||||
from examples.pipeline.wiki_entity_linking import wikipedia_processor as wp, kb_creator, training_set_creator, run_el
|
||||
from examples.pipeline.wiki_entity_linking import wikipedia_processor as wp, kb_creator, training_set_creator, run_el, train_el
|
||||
|
||||
import spacy
|
||||
from spacy.vocab import Vocab
|
||||
|
@ -15,11 +15,12 @@ Demonstrate how to build a knowledge base from WikiData and run an Entity Linkin
|
|||
PRIOR_PROB = 'C:/Users/Sofie/Documents/data/wikipedia/prior_prob.csv'
|
||||
ENTITY_COUNTS = 'C:/Users/Sofie/Documents/data/wikipedia/entity_freq.csv'
|
||||
ENTITY_DEFS = 'C:/Users/Sofie/Documents/data/wikipedia/entity_defs.csv'
|
||||
ENTITY_DESCR = 'C:/Users/Sofie/Documents/data/wikipedia/entity_descriptions.csv'
|
||||
|
||||
KB_FILE = 'C:/Users/Sofie/Documents/data/wikipedia/kb'
|
||||
VOCAB_DIR = 'C:/Users/Sofie/Documents/data/wikipedia/vocab'
|
||||
|
||||
TRAINING_DIR = 'C:/Users/Sofie/Documents/data/wikipedia/training_nel/'
|
||||
TRAINING_DIR = 'C:/Users/Sofie/Documents/data/wikipedia/training_data_nel/'
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -30,17 +31,20 @@ if __name__ == "__main__":
|
|||
# one-time methods to create KB and write to file
|
||||
to_create_prior_probs = False
|
||||
to_create_entity_counts = False
|
||||
to_create_kb = False
|
||||
to_create_kb = True
|
||||
|
||||
# read KB back in from file
|
||||
to_read_kb = True
|
||||
to_test_kb = False
|
||||
to_test_kb = True
|
||||
|
||||
# create training dataset
|
||||
create_wp_training = False
|
||||
|
||||
# apply named entity linking to the training dataset
|
||||
apply_to_training = True
|
||||
# run training
|
||||
run_training = False
|
||||
|
||||
# apply named entity linking to the dev dataset
|
||||
apply_to_dev = False
|
||||
|
||||
# STEP 1 : create prior probabilities from WP
|
||||
# run only once !
|
||||
|
@ -65,7 +69,8 @@ if __name__ == "__main__":
|
|||
my_kb = kb_creator.create_kb(my_vocab,
|
||||
max_entities_per_alias=10,
|
||||
min_occ=5,
|
||||
entity_output=ENTITY_DEFS,
|
||||
entity_def_output=ENTITY_DEFS,
|
||||
entity_descr_output=ENTITY_DESCR,
|
||||
count_input=ENTITY_COUNTS,
|
||||
prior_prob_input=PRIOR_PROB,
|
||||
to_print=False)
|
||||
|
@ -98,12 +103,19 @@ if __name__ == "__main__":
|
|||
# STEP 5: create a training dataset from WP
|
||||
if create_wp_training:
|
||||
print("STEP 5: create training dataset", datetime.datetime.now())
|
||||
training_set_creator.create_training(kb=my_kb, entity_input=ENTITY_DEFS, training_output=TRAINING_DIR)
|
||||
training_set_creator.create_training(kb=my_kb, entity_def_input=ENTITY_DEFS, training_output=TRAINING_DIR)
|
||||
|
||||
# STEP 6: apply the EL algorithm on the training dataset
|
||||
if apply_to_training:
|
||||
# STEP 7: apply the EL algorithm on the training dataset
|
||||
if run_training:
|
||||
print("STEP 6: training ", datetime.datetime.now())
|
||||
my_nlp = spacy.load('en_core_web_sm')
|
||||
run_el.run_el_training(kb=my_kb, nlp=my_nlp, training_dir=TRAINING_DIR, limit=1000)
|
||||
train_el.train_model(kb=my_kb, nlp=my_nlp, training_dir=TRAINING_DIR, entity_descr_output=ENTITY_DESCR, limit=5)
|
||||
print()
|
||||
|
||||
# STEP 8: apply the EL algorithm on the dev dataset
|
||||
if apply_to_dev:
|
||||
my_nlp = spacy.load('en_core_web_sm')
|
||||
run_el.run_el_dev(kb=my_kb, nlp=my_nlp, training_dir=TRAINING_DIR, limit=2000)
|
||||
print()
|
||||
|
||||
|
||||
|
|
|
@ -13,17 +13,18 @@ WIKIDATA_JSON = 'C:/Users/Sofie/Documents/data/wikidata/wikidata-20190304-all.js
|
|||
def read_wikidata_entities_json(limit=None, to_print=False):
|
||||
""" Read the JSON wiki data and parse out the entities. Takes about 7u30 to parse 55M lines. """
|
||||
|
||||
languages = {'en', 'de'}
|
||||
lang = 'en'
|
||||
prop_filter = {'P31': {'Q5', 'Q15632617'}} # currently defined as OR: one property suffices to be selected
|
||||
site_filter = 'enwiki'
|
||||
|
||||
title_to_id = dict()
|
||||
id_to_descr = dict()
|
||||
|
||||
# parse appropriate fields - depending on what we need in the KB
|
||||
parse_properties = False
|
||||
parse_sitelinks = True
|
||||
parse_labels = False
|
||||
parse_descriptions = False
|
||||
parse_descriptions = True
|
||||
parse_aliases = False
|
||||
|
||||
with bz2.open(WIKIDATA_JSON, mode='rb') as file:
|
||||
|
@ -76,91 +77,36 @@ def read_wikidata_entities_json(limit=None, to_print=False):
|
|||
if to_print:
|
||||
print(site_filter, ":", site)
|
||||
title_to_id[site] = unique_id
|
||||
# print(site, "for", unique_id)
|
||||
|
||||
if parse_labels:
|
||||
labels = obj["labels"]
|
||||
if labels:
|
||||
for lang in languages:
|
||||
lang_label = labels.get(lang, None)
|
||||
if lang_label:
|
||||
if to_print:
|
||||
print("label (" + lang + "):", lang_label["value"])
|
||||
lang_label = labels.get(lang, None)
|
||||
if lang_label:
|
||||
if to_print:
|
||||
print("label (" + lang + "):", lang_label["value"])
|
||||
|
||||
if parse_descriptions:
|
||||
descriptions = obj["descriptions"]
|
||||
if descriptions:
|
||||
for lang in languages:
|
||||
lang_descr = descriptions.get(lang, None)
|
||||
if lang_descr:
|
||||
if to_print:
|
||||
print("description (" + lang + "):", lang_descr["value"])
|
||||
lang_descr = descriptions.get(lang, None)
|
||||
if lang_descr:
|
||||
if to_print:
|
||||
print("description (" + lang + "):", lang_descr["value"])
|
||||
id_to_descr[unique_id] = lang_descr["value"]
|
||||
|
||||
if parse_aliases:
|
||||
aliases = obj["aliases"]
|
||||
if aliases:
|
||||
for lang in languages:
|
||||
lang_aliases = aliases.get(lang, None)
|
||||
if lang_aliases:
|
||||
for item in lang_aliases:
|
||||
if to_print:
|
||||
print("alias (" + lang + "):", item["value"])
|
||||
lang_aliases = aliases.get(lang, None)
|
||||
if lang_aliases:
|
||||
for item in lang_aliases:
|
||||
if to_print:
|
||||
print("alias (" + lang + "):", item["value"])
|
||||
|
||||
if to_print:
|
||||
print()
|
||||
line = file.readline()
|
||||
cnt += 1
|
||||
|
||||
return title_to_id
|
||||
|
||||
|
||||
def _read_wikidata_entities_regex_depr(limit=None):
|
||||
"""
|
||||
Read the JSON wiki data and parse out the entities with regular expressions. Takes XXX to parse 55M lines.
|
||||
TODO: doesn't work yet. may be deleted ?
|
||||
"""
|
||||
|
||||
regex_p31 = re.compile(r'mainsnak[^}]*\"P31\"[^}]*}', re.UNICODE)
|
||||
regex_id = re.compile(r'\"id\":"Q[0-9]*"', re.UNICODE)
|
||||
regex_enwiki = re.compile(r'\"enwiki\":[^}]*}', re.UNICODE)
|
||||
regex_title = re.compile(r'\"title\":"[^"]*"', re.UNICODE)
|
||||
|
||||
title_to_id = dict()
|
||||
|
||||
with bz2.open(WIKIDATA_JSON, mode='rb') as file:
|
||||
line = file.readline()
|
||||
cnt = 0
|
||||
while line and (not limit or cnt < limit):
|
||||
if cnt % 500000 == 0:
|
||||
print(datetime.datetime.now(), "processed", cnt, "lines of WikiData dump")
|
||||
clean_line = line.strip()
|
||||
if clean_line.endswith(b","):
|
||||
clean_line = clean_line[:-1]
|
||||
if len(clean_line) > 1:
|
||||
clean_line = line.strip().decode("utf-8")
|
||||
keep = False
|
||||
|
||||
p31_matches = regex_p31.findall(clean_line)
|
||||
if p31_matches:
|
||||
for p31_match in p31_matches:
|
||||
id_matches = regex_id.findall(p31_match)
|
||||
for id_match in id_matches:
|
||||
id_match = id_match[6:][:-1]
|
||||
if id_match == "Q5" or id_match == "Q15632617":
|
||||
keep = True
|
||||
|
||||
if keep:
|
||||
id_match = regex_id.search(clean_line).group(0)
|
||||
id_match = id_match[6:][:-1]
|
||||
|
||||
enwiki_matches = regex_enwiki.findall(clean_line)
|
||||
if enwiki_matches:
|
||||
for enwiki_match in enwiki_matches:
|
||||
title_match = regex_title.search(enwiki_match).group(0)
|
||||
title = title_match[9:][:-1]
|
||||
title_to_id[title] = id_match
|
||||
|
||||
line = file.readline()
|
||||
cnt += 1
|
||||
|
||||
return title_to_id
|
||||
return title_to_id, id_to_descr
|
||||
|
|
Loading…
Reference in New Issue
Block a user