mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-26 18:06:29 +03:00
filter training data beforehand (+black formatting)
This commit is contained in:
parent
d833d4c358
commit
ec55d2fccd
|
@ -18,6 +18,10 @@ Gold-standard entities are stored in one file in standoff format (by character o
|
||||||
ENTITY_FILE = "gold_entities.csv"
|
ENTITY_FILE = "gold_entities.csv"
|
||||||
|
|
||||||
|
|
||||||
|
def now():
|
||||||
|
return datetime.datetime.now()
|
||||||
|
|
||||||
|
|
||||||
def create_training(wikipedia_input, entity_def_input, training_output):
|
def create_training(wikipedia_input, entity_def_input, training_output):
|
||||||
wp_to_id = kb_creator.get_entity_to_id(entity_def_input)
|
wp_to_id = kb_creator.get_entity_to_id(entity_def_input)
|
||||||
_process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=None)
|
_process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=None)
|
||||||
|
@ -54,12 +58,7 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=N
|
||||||
reading_revision = False
|
reading_revision = False
|
||||||
while line and (not limit or cnt < limit):
|
while line and (not limit or cnt < limit):
|
||||||
if cnt % 1000000 == 0:
|
if cnt % 1000000 == 0:
|
||||||
print(
|
print(now(), "processed", cnt, "lines of Wikipedia dump")
|
||||||
datetime.datetime.now(),
|
|
||||||
"processed",
|
|
||||||
cnt,
|
|
||||||
"lines of Wikipedia dump",
|
|
||||||
)
|
|
||||||
clean_line = line.strip().decode("utf-8")
|
clean_line = line.strip().decode("utf-8")
|
||||||
|
|
||||||
if clean_line == "<revision>":
|
if clean_line == "<revision>":
|
||||||
|
@ -328,8 +327,9 @@ def is_dev(article_id):
|
||||||
|
|
||||||
def read_training(nlp, training_dir, dev, limit, kb=None):
|
def read_training(nlp, training_dir, dev, limit, kb=None):
|
||||||
""" This method provides training examples that correspond to the entity annotations found by the nlp object.
|
""" This method provides training examples that correspond to the entity annotations found by the nlp object.
|
||||||
When kb is provided, it will include also negative training examples by using the candidate generator.
|
When kb is provided (for training), it will include negative training examples by using the candidate generator,
|
||||||
When kb=None, it will only include positive training examples."""
|
and it will only keep positive training examples that can be found in the KB.
|
||||||
|
When kb=None (for testing), it will include all positive examples only."""
|
||||||
entityfile_loc = training_dir / ENTITY_FILE
|
entityfile_loc = training_dir / ENTITY_FILE
|
||||||
data = []
|
data = []
|
||||||
|
|
||||||
|
@ -402,12 +402,11 @@ def read_training(nlp, training_dir, dev, limit, kb=None):
|
||||||
gold_end = int(end) - found_ent.sent.start_char
|
gold_end = int(end) - found_ent.sent.start_char
|
||||||
|
|
||||||
# add both pos and neg examples (in random order)
|
# add both pos and neg examples (in random order)
|
||||||
|
# this will exclude examples not in the KB
|
||||||
if kb:
|
if kb:
|
||||||
gold_entities = {}
|
gold_entities = {}
|
||||||
candidates = kb.get_candidates(alias)
|
candidates = kb.get_candidates(alias)
|
||||||
candidate_ids = [c.entity_ for c in candidates]
|
candidate_ids = [c.entity_ for c in candidates]
|
||||||
# add positive example in case the KB doesn't have it
|
|
||||||
candidate_ids.append(wd_id)
|
|
||||||
random.shuffle(candidate_ids)
|
random.shuffle(candidate_ids)
|
||||||
for kb_id in candidate_ids:
|
for kb_id in candidate_ids:
|
||||||
entry = (gold_start, gold_end, kb_id)
|
entry = (gold_start, gold_end, kb_id)
|
||||||
|
@ -415,6 +414,7 @@ def read_training(nlp, training_dir, dev, limit, kb=None):
|
||||||
gold_entities[entry] = 0.0
|
gold_entities[entry] = 0.0
|
||||||
else:
|
else:
|
||||||
gold_entities[entry] = 1.0
|
gold_entities[entry] = 1.0
|
||||||
|
# keep all positive examples
|
||||||
else:
|
else:
|
||||||
entry = (gold_start, gold_end, wd_id)
|
entry = (gold_start, gold_end, wd_id)
|
||||||
gold_entities = {entry: 1.0}
|
gold_entities = {entry: 1.0}
|
||||||
|
|
|
@ -14,22 +14,97 @@ Write these results to file for downstream KB and training data generation.
|
||||||
map_alias_to_link = dict()
|
map_alias_to_link = dict()
|
||||||
|
|
||||||
# these will/should be matched ignoring case
|
# these will/should be matched ignoring case
|
||||||
wiki_namespaces = ["b", "betawikiversity", "Book", "c", "Category", "Commons",
|
wiki_namespaces = [
|
||||||
"d", "dbdump", "download", "Draft", "Education", "Foundation",
|
"b",
|
||||||
"Gadget", "Gadget definition", "gerrit", "File", "Help", "Image", "Incubator",
|
"betawikiversity",
|
||||||
"m", "mail", "mailarchive", "media", "MediaWiki", "MediaWiki talk", "Mediawikiwiki",
|
"Book",
|
||||||
"MediaZilla", "Meta", "Metawikipedia", "Module",
|
"c",
|
||||||
"mw", "n", "nost", "oldwikisource", "outreach", "outreachwiki", "otrs", "OTRSwiki",
|
"Category",
|
||||||
"Portal", "phab", "Phabricator", "Project", "q", "quality", "rev",
|
"Commons",
|
||||||
"s", "spcom", "Special", "species", "Strategy", "sulutil", "svn",
|
"d",
|
||||||
"Talk", "Template", "Template talk", "Testwiki", "ticket", "TimedText", "Toollabs", "tools",
|
"dbdump",
|
||||||
"tswiki", "User", "User talk", "v", "voy",
|
"download",
|
||||||
"w", "Wikibooks", "Wikidata", "wikiHow", "Wikinvest", "wikilivres", "Wikimedia", "Wikinews",
|
"Draft",
|
||||||
"Wikipedia", "Wikipedia talk", "Wikiquote", "Wikisource", "Wikispecies", "Wikitech",
|
"Education",
|
||||||
"Wikiversity", "Wikivoyage", "wikt", "wiktionary", "wmf", "wmania", "WP"]
|
"Foundation",
|
||||||
|
"Gadget",
|
||||||
|
"Gadget definition",
|
||||||
|
"gerrit",
|
||||||
|
"File",
|
||||||
|
"Help",
|
||||||
|
"Image",
|
||||||
|
"Incubator",
|
||||||
|
"m",
|
||||||
|
"mail",
|
||||||
|
"mailarchive",
|
||||||
|
"media",
|
||||||
|
"MediaWiki",
|
||||||
|
"MediaWiki talk",
|
||||||
|
"Mediawikiwiki",
|
||||||
|
"MediaZilla",
|
||||||
|
"Meta",
|
||||||
|
"Metawikipedia",
|
||||||
|
"Module",
|
||||||
|
"mw",
|
||||||
|
"n",
|
||||||
|
"nost",
|
||||||
|
"oldwikisource",
|
||||||
|
"outreach",
|
||||||
|
"outreachwiki",
|
||||||
|
"otrs",
|
||||||
|
"OTRSwiki",
|
||||||
|
"Portal",
|
||||||
|
"phab",
|
||||||
|
"Phabricator",
|
||||||
|
"Project",
|
||||||
|
"q",
|
||||||
|
"quality",
|
||||||
|
"rev",
|
||||||
|
"s",
|
||||||
|
"spcom",
|
||||||
|
"Special",
|
||||||
|
"species",
|
||||||
|
"Strategy",
|
||||||
|
"sulutil",
|
||||||
|
"svn",
|
||||||
|
"Talk",
|
||||||
|
"Template",
|
||||||
|
"Template talk",
|
||||||
|
"Testwiki",
|
||||||
|
"ticket",
|
||||||
|
"TimedText",
|
||||||
|
"Toollabs",
|
||||||
|
"tools",
|
||||||
|
"tswiki",
|
||||||
|
"User",
|
||||||
|
"User talk",
|
||||||
|
"v",
|
||||||
|
"voy",
|
||||||
|
"w",
|
||||||
|
"Wikibooks",
|
||||||
|
"Wikidata",
|
||||||
|
"wikiHow",
|
||||||
|
"Wikinvest",
|
||||||
|
"wikilivres",
|
||||||
|
"Wikimedia",
|
||||||
|
"Wikinews",
|
||||||
|
"Wikipedia",
|
||||||
|
"Wikipedia talk",
|
||||||
|
"Wikiquote",
|
||||||
|
"Wikisource",
|
||||||
|
"Wikispecies",
|
||||||
|
"Wikitech",
|
||||||
|
"Wikiversity",
|
||||||
|
"Wikivoyage",
|
||||||
|
"wikt",
|
||||||
|
"wiktionary",
|
||||||
|
"wmf",
|
||||||
|
"wmania",
|
||||||
|
"WP",
|
||||||
|
]
|
||||||
|
|
||||||
# find the links
|
# find the links
|
||||||
link_regex = re.compile(r'\[\[[^\[\]]*\]\]')
|
link_regex = re.compile(r"\[\[[^\[\]]*\]\]")
|
||||||
|
|
||||||
# match on interwiki links, e.g. `en:` or `:fr:`
|
# match on interwiki links, e.g. `en:` or `:fr:`
|
||||||
ns_regex = r":?" + "[a-z][a-z]" + ":"
|
ns_regex = r":?" + "[a-z][a-z]" + ":"
|
||||||
|
@ -41,18 +116,22 @@ for ns in wiki_namespaces:
|
||||||
ns_regex = re.compile(ns_regex, re.IGNORECASE)
|
ns_regex = re.compile(ns_regex, re.IGNORECASE)
|
||||||
|
|
||||||
|
|
||||||
def read_wikipedia_prior_probs(wikipedia_input, prior_prob_output):
|
def now():
|
||||||
|
return datetime.datetime.now()
|
||||||
|
|
||||||
|
|
||||||
|
def read_prior_probs(wikipedia_input, prior_prob_output):
|
||||||
"""
|
"""
|
||||||
Read the XML wikipedia data and parse out intra-wiki links to estimate prior probabilities.
|
Read the XML wikipedia data and parse out intra-wiki links to estimate prior probabilities.
|
||||||
The full file takes about 2h to parse 1100M lines.
|
The full file takes about 2h to parse 1100M lines.
|
||||||
It works relatively fast because it runs line by line, irrelevant of which article the intrawiki is from.
|
It works relatively fast because it runs line by line, irrelevant of which article the intrawiki is from.
|
||||||
"""
|
"""
|
||||||
with bz2.open(wikipedia_input, mode='rb') as file:
|
with bz2.open(wikipedia_input, mode="rb") as file:
|
||||||
line = file.readline()
|
line = file.readline()
|
||||||
cnt = 0
|
cnt = 0
|
||||||
while line:
|
while line:
|
||||||
if cnt % 5000000 == 0:
|
if cnt % 5000000 == 0:
|
||||||
print(datetime.datetime.now(), "processed", cnt, "lines of Wikipedia dump")
|
print(now(), "processed", cnt, "lines of Wikipedia dump")
|
||||||
clean_line = line.strip().decode("utf-8")
|
clean_line = line.strip().decode("utf-8")
|
||||||
|
|
||||||
aliases, entities, normalizations = get_wp_links(clean_line)
|
aliases, entities, normalizations = get_wp_links(clean_line)
|
||||||
|
@ -64,10 +143,11 @@ def read_wikipedia_prior_probs(wikipedia_input, prior_prob_output):
|
||||||
cnt += 1
|
cnt += 1
|
||||||
|
|
||||||
# write all aliases and their entities and count occurrences to file
|
# write all aliases and their entities and count occurrences to file
|
||||||
with open(prior_prob_output, mode='w', encoding='utf8') as outputfile:
|
with open(prior_prob_output, mode="w", encoding="utf8") as outputfile:
|
||||||
outputfile.write("alias" + "|" + "count" + "|" + "entity" + "\n")
|
outputfile.write("alias" + "|" + "count" + "|" + "entity" + "\n")
|
||||||
for alias, alias_dict in sorted(map_alias_to_link.items(), key=lambda x: x[0]):
|
for alias, alias_dict in sorted(map_alias_to_link.items(), key=lambda x: x[0]):
|
||||||
for entity, count in sorted(alias_dict.items(), key=lambda x: x[1], reverse=True):
|
s_dict = sorted(alias_dict.items(), key=lambda x: x[1], reverse=True)
|
||||||
|
for entity, count in s_dict:
|
||||||
outputfile.write(alias + "|" + str(count) + "|" + entity + "\n")
|
outputfile.write(alias + "|" + str(count) + "|" + entity + "\n")
|
||||||
|
|
||||||
|
|
||||||
|
@ -140,13 +220,13 @@ def write_entity_counts(prior_prob_input, count_output, to_print=False):
|
||||||
entity_to_count = dict()
|
entity_to_count = dict()
|
||||||
total_count = 0
|
total_count = 0
|
||||||
|
|
||||||
with open(prior_prob_input, mode='r', encoding='utf8') as prior_file:
|
with open(prior_prob_input, mode="r", encoding="utf8") as prior_file:
|
||||||
# skip header
|
# skip header
|
||||||
prior_file.readline()
|
prior_file.readline()
|
||||||
line = prior_file.readline()
|
line = prior_file.readline()
|
||||||
|
|
||||||
while line:
|
while line:
|
||||||
splits = line.replace('\n', "").split(sep='|')
|
splits = line.replace("\n", "").split(sep="|")
|
||||||
# alias = splits[0]
|
# alias = splits[0]
|
||||||
count = int(splits[1])
|
count = int(splits[1])
|
||||||
entity = splits[2]
|
entity = splits[2]
|
||||||
|
@ -158,7 +238,7 @@ def write_entity_counts(prior_prob_input, count_output, to_print=False):
|
||||||
|
|
||||||
line = prior_file.readline()
|
line = prior_file.readline()
|
||||||
|
|
||||||
with open(count_output, mode='w', encoding='utf8') as entity_file:
|
with open(count_output, mode="w", encoding="utf8") as entity_file:
|
||||||
entity_file.write("entity" + "|" + "count" + "\n")
|
entity_file.write("entity" + "|" + "count" + "\n")
|
||||||
for entity, count in entity_to_count.items():
|
for entity, count in entity_to_count.items():
|
||||||
entity_file.write(entity + "|" + str(count) + "\n")
|
entity_file.write(entity + "|" + str(count) + "\n")
|
||||||
|
@ -171,12 +251,11 @@ def write_entity_counts(prior_prob_input, count_output, to_print=False):
|
||||||
|
|
||||||
def get_all_frequencies(count_input):
|
def get_all_frequencies(count_input):
|
||||||
entity_to_count = dict()
|
entity_to_count = dict()
|
||||||
with open(count_input, 'r', encoding='utf8') as csvfile:
|
with open(count_input, "r", encoding="utf8") as csvfile:
|
||||||
csvreader = csv.reader(csvfile, delimiter='|')
|
csvreader = csv.reader(csvfile, delimiter="|")
|
||||||
# skip header
|
# skip header
|
||||||
next(csvreader)
|
next(csvreader)
|
||||||
for row in csvreader:
|
for row in csvreader:
|
||||||
entity_to_count[row[0]] = int(row[1])
|
entity_to_count[row[0]] = int(row[1])
|
||||||
|
|
||||||
return entity_to_count
|
return entity_to_count
|
||||||
|
|
||||||
|
|
|
@ -5,7 +5,8 @@ import random
|
||||||
import datetime
|
import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from bin.wiki_entity_linking import training_set_creator, kb_creator, wikipedia_processor as wp
|
from bin.wiki_entity_linking import wikipedia_processor as wp
|
||||||
|
from bin.wiki_entity_linking import training_set_creator, kb_creator
|
||||||
from bin.wiki_entity_linking.kb_creator import DESC_WIDTH
|
from bin.wiki_entity_linking.kb_creator import DESC_WIDTH
|
||||||
|
|
||||||
import spacy
|
import spacy
|
||||||
|
@ -17,23 +18,25 @@ Demonstrate how to build a knowledge base from WikiData and run an Entity Linkin
|
||||||
"""
|
"""
|
||||||
|
|
||||||
ROOT_DIR = Path("C:/Users/Sofie/Documents/data/")
|
ROOT_DIR = Path("C:/Users/Sofie/Documents/data/")
|
||||||
OUTPUT_DIR = ROOT_DIR / 'wikipedia'
|
OUTPUT_DIR = ROOT_DIR / "wikipedia"
|
||||||
TRAINING_DIR = OUTPUT_DIR / 'training_data_nel'
|
TRAINING_DIR = OUTPUT_DIR / "training_data_nel"
|
||||||
|
|
||||||
PRIOR_PROB = OUTPUT_DIR / 'prior_prob.csv'
|
PRIOR_PROB = OUTPUT_DIR / "prior_prob.csv"
|
||||||
ENTITY_COUNTS = OUTPUT_DIR / 'entity_freq.csv'
|
ENTITY_COUNTS = OUTPUT_DIR / "entity_freq.csv"
|
||||||
ENTITY_DEFS = OUTPUT_DIR / 'entity_defs.csv'
|
ENTITY_DEFS = OUTPUT_DIR / "entity_defs.csv"
|
||||||
ENTITY_DESCR = OUTPUT_DIR / 'entity_descriptions.csv'
|
ENTITY_DESCR = OUTPUT_DIR / "entity_descriptions.csv"
|
||||||
|
|
||||||
KB_FILE = OUTPUT_DIR / 'kb_1' / 'kb'
|
KB_FILE = OUTPUT_DIR / "kb_1" / "kb"
|
||||||
NLP_1_DIR = OUTPUT_DIR / 'nlp_1'
|
NLP_1_DIR = OUTPUT_DIR / "nlp_1"
|
||||||
NLP_2_DIR = OUTPUT_DIR / 'nlp_2'
|
NLP_2_DIR = OUTPUT_DIR / "nlp_2"
|
||||||
|
|
||||||
# get latest-all.json.bz2 from https://dumps.wikimedia.org/wikidatawiki/entities/
|
# get latest-all.json.bz2 from https://dumps.wikimedia.org/wikidatawiki/entities/
|
||||||
WIKIDATA_JSON = ROOT_DIR / 'wikidata' / 'wikidata-20190304-all.json.bz2'
|
WIKIDATA_JSON = ROOT_DIR / "wikidata" / "wikidata-20190304-all.json.bz2"
|
||||||
|
|
||||||
# get enwiki-latest-pages-articles-multistream.xml.bz2 from https://dumps.wikimedia.org/enwiki/latest/
|
# get enwiki-latest-pages-articles-multistream.xml.bz2 from https://dumps.wikimedia.org/enwiki/latest/
|
||||||
ENWIKI_DUMP = ROOT_DIR / 'wikipedia' / 'enwiki-20190320-pages-articles-multistream.xml.bz2'
|
ENWIKI_DUMP = (
|
||||||
|
ROOT_DIR / "wikipedia" / "enwiki-20190320-pages-articles-multistream.xml.bz2"
|
||||||
|
)
|
||||||
|
|
||||||
# KB construction parameters
|
# KB construction parameters
|
||||||
MAX_CANDIDATES = 10
|
MAX_CANDIDATES = 10
|
||||||
|
@ -48,11 +51,15 @@ L2 = 1e-6
|
||||||
CONTEXT_WIDTH = 128
|
CONTEXT_WIDTH = 128
|
||||||
|
|
||||||
|
|
||||||
|
def now():
|
||||||
|
return datetime.datetime.now()
|
||||||
|
|
||||||
|
|
||||||
def run_pipeline():
|
def run_pipeline():
|
||||||
# set the appropriate booleans to define which parts of the pipeline should be re(run)
|
# set the appropriate booleans to define which parts of the pipeline should be re(run)
|
||||||
print("START", datetime.datetime.now())
|
print("START", now())
|
||||||
print()
|
print()
|
||||||
nlp_1 = spacy.load('en_core_web_lg')
|
nlp_1 = spacy.load("en_core_web_lg")
|
||||||
nlp_2 = None
|
nlp_2 = None
|
||||||
kb_2 = None
|
kb_2 = None
|
||||||
|
|
||||||
|
@ -82,20 +89,21 @@ def run_pipeline():
|
||||||
|
|
||||||
# STEP 1 : create prior probabilities from WP (run only once)
|
# STEP 1 : create prior probabilities from WP (run only once)
|
||||||
if to_create_prior_probs:
|
if to_create_prior_probs:
|
||||||
print("STEP 1: to_create_prior_probs", datetime.datetime.now())
|
print("STEP 1: to_create_prior_probs", now())
|
||||||
wp.read_wikipedia_prior_probs(wikipedia_input=ENWIKI_DUMP, prior_prob_output=PRIOR_PROB)
|
wp.read_prior_probs(ENWIKI_DUMP, PRIOR_PROB)
|
||||||
print()
|
print()
|
||||||
|
|
||||||
# STEP 2 : deduce entity frequencies from WP (run only once)
|
# STEP 2 : deduce entity frequencies from WP (run only once)
|
||||||
if to_create_entity_counts:
|
if to_create_entity_counts:
|
||||||
print("STEP 2: to_create_entity_counts", datetime.datetime.now())
|
print("STEP 2: to_create_entity_counts", now())
|
||||||
wp.write_entity_counts(prior_prob_input=PRIOR_PROB, count_output=ENTITY_COUNTS, to_print=False)
|
wp.write_entity_counts(PRIOR_PROB, ENTITY_COUNTS, to_print=False)
|
||||||
print()
|
print()
|
||||||
|
|
||||||
# STEP 3 : create KB and write to file (run only once)
|
# STEP 3 : create KB and write to file (run only once)
|
||||||
if to_create_kb:
|
if to_create_kb:
|
||||||
print("STEP 3a: to_create_kb", datetime.datetime.now())
|
print("STEP 3a: to_create_kb", now())
|
||||||
kb_1 = kb_creator.create_kb(nlp_1,
|
kb_1 = kb_creator.create_kb(
|
||||||
|
nlp=nlp_1,
|
||||||
max_entities_per_alias=MAX_CANDIDATES,
|
max_entities_per_alias=MAX_CANDIDATES,
|
||||||
min_entity_freq=MIN_ENTITY_FREQ,
|
min_entity_freq=MIN_ENTITY_FREQ,
|
||||||
min_occ=MIN_PAIR_OCC,
|
min_occ=MIN_PAIR_OCC,
|
||||||
|
@ -103,19 +111,20 @@ def run_pipeline():
|
||||||
entity_descr_output=ENTITY_DESCR,
|
entity_descr_output=ENTITY_DESCR,
|
||||||
count_input=ENTITY_COUNTS,
|
count_input=ENTITY_COUNTS,
|
||||||
prior_prob_input=PRIOR_PROB,
|
prior_prob_input=PRIOR_PROB,
|
||||||
wikidata_input=WIKIDATA_JSON)
|
wikidata_input=WIKIDATA_JSON,
|
||||||
|
)
|
||||||
print("kb entities:", kb_1.get_size_entities())
|
print("kb entities:", kb_1.get_size_entities())
|
||||||
print("kb aliases:", kb_1.get_size_aliases())
|
print("kb aliases:", kb_1.get_size_aliases())
|
||||||
print()
|
print()
|
||||||
|
|
||||||
print("STEP 3b: write KB and NLP", datetime.datetime.now())
|
print("STEP 3b: write KB and NLP", now())
|
||||||
kb_1.dump(KB_FILE)
|
kb_1.dump(KB_FILE)
|
||||||
nlp_1.to_disk(NLP_1_DIR)
|
nlp_1.to_disk(NLP_1_DIR)
|
||||||
print()
|
print()
|
||||||
|
|
||||||
# STEP 4 : read KB back in from file
|
# STEP 4 : read KB back in from file
|
||||||
if to_read_kb:
|
if to_read_kb:
|
||||||
print("STEP 4: to_read_kb", datetime.datetime.now())
|
print("STEP 4: to_read_kb", now())
|
||||||
nlp_2 = spacy.load(NLP_1_DIR)
|
nlp_2 = spacy.load(NLP_1_DIR)
|
||||||
kb_2 = KnowledgeBase(vocab=nlp_2.vocab, entity_vector_length=DESC_WIDTH)
|
kb_2 = KnowledgeBase(vocab=nlp_2.vocab, entity_vector_length=DESC_WIDTH)
|
||||||
kb_2.load_bulk(KB_FILE)
|
kb_2.load_bulk(KB_FILE)
|
||||||
|
@ -130,20 +139,26 @@ def run_pipeline():
|
||||||
|
|
||||||
# STEP 5: create a training dataset from WP
|
# STEP 5: create a training dataset from WP
|
||||||
if create_wp_training:
|
if create_wp_training:
|
||||||
print("STEP 5: create training dataset", datetime.datetime.now())
|
print("STEP 5: create training dataset", now())
|
||||||
training_set_creator.create_training(wikipedia_input=ENWIKI_DUMP,
|
training_set_creator.create_training(
|
||||||
|
wikipedia_input=ENWIKI_DUMP,
|
||||||
entity_def_input=ENTITY_DEFS,
|
entity_def_input=ENTITY_DEFS,
|
||||||
training_output=TRAINING_DIR)
|
training_output=TRAINING_DIR,
|
||||||
|
)
|
||||||
|
|
||||||
# STEP 6: create and train the entity linking pipe
|
# STEP 6: create and train the entity linking pipe
|
||||||
if train_pipe:
|
if train_pipe:
|
||||||
print("STEP 6: training Entity Linking pipe", datetime.datetime.now())
|
print("STEP 6: training Entity Linking pipe", now())
|
||||||
type_to_int = {label: i for i, label in enumerate(nlp_2.entity.labels)}
|
type_to_int = {label: i for i, label in enumerate(nlp_2.entity.labels)}
|
||||||
print(" -analysing", len(type_to_int), "different entity types")
|
print(" -analysing", len(type_to_int), "different entity types")
|
||||||
el_pipe = nlp_2.create_pipe(name='entity_linker',
|
el_pipe = nlp_2.create_pipe(
|
||||||
config={"context_width": CONTEXT_WIDTH,
|
name="entity_linker",
|
||||||
|
config={
|
||||||
|
"context_width": CONTEXT_WIDTH,
|
||||||
"pretrained_vectors": nlp_2.vocab.vectors.name,
|
"pretrained_vectors": nlp_2.vocab.vectors.name,
|
||||||
"type_to_int": type_to_int})
|
"type_to_int": type_to_int,
|
||||||
|
},
|
||||||
|
)
|
||||||
el_pipe.set_kb(kb_2)
|
el_pipe.set_kb(kb_2)
|
||||||
nlp_2.add_pipe(el_pipe, last=True)
|
nlp_2.add_pipe(el_pipe, last=True)
|
||||||
|
|
||||||
|
@ -157,18 +172,22 @@ def run_pipeline():
|
||||||
train_limit = 5000
|
train_limit = 5000
|
||||||
dev_limit = 5000
|
dev_limit = 5000
|
||||||
|
|
||||||
train_data = training_set_creator.read_training(nlp=nlp_2,
|
# for training, get pos & neg instances that correspond to entries in the kb
|
||||||
|
train_data = training_set_creator.read_training(
|
||||||
|
nlp=nlp_2,
|
||||||
training_dir=TRAINING_DIR,
|
training_dir=TRAINING_DIR,
|
||||||
dev=False,
|
dev=False,
|
||||||
limit=train_limit)
|
limit=train_limit,
|
||||||
|
kb=el_pipe.kb,
|
||||||
|
)
|
||||||
|
|
||||||
print("Training on", len(train_data), "articles")
|
print("Training on", len(train_data), "articles")
|
||||||
print()
|
print()
|
||||||
|
|
||||||
dev_data = training_set_creator.read_training(nlp=nlp_2,
|
# for testing, get all pos instances, whether or not they are in the kb
|
||||||
training_dir=TRAINING_DIR,
|
dev_data = training_set_creator.read_training(
|
||||||
dev=True,
|
nlp=nlp_2, training_dir=TRAINING_DIR, dev=True, limit=dev_limit, kb=None
|
||||||
limit=dev_limit)
|
)
|
||||||
|
|
||||||
print("Dev testing on", len(dev_data), "articles")
|
print("Dev testing on", len(dev_data), "articles")
|
||||||
print()
|
print()
|
||||||
|
@ -187,8 +206,8 @@ def run_pipeline():
|
||||||
try:
|
try:
|
||||||
docs, golds = zip(*batch)
|
docs, golds = zip(*batch)
|
||||||
nlp_2.update(
|
nlp_2.update(
|
||||||
docs,
|
docs=docs,
|
||||||
golds,
|
golds=golds,
|
||||||
sgd=optimizer,
|
sgd=optimizer,
|
||||||
drop=DROPOUT,
|
drop=DROPOUT,
|
||||||
losses=losses,
|
losses=losses,
|
||||||
|
@ -200,48 +219,61 @@ def run_pipeline():
|
||||||
if batchnr > 0:
|
if batchnr > 0:
|
||||||
el_pipe.cfg["context_weight"] = 1
|
el_pipe.cfg["context_weight"] = 1
|
||||||
el_pipe.cfg["prior_weight"] = 1
|
el_pipe.cfg["prior_weight"] = 1
|
||||||
dev_acc_context, dev_acc_context_dict = _measure_accuracy(dev_data, el_pipe)
|
dev_acc_context, _ = _measure_acc(dev_data, el_pipe)
|
||||||
losses['entity_linker'] = losses['entity_linker'] / batchnr
|
losses["entity_linker"] = losses["entity_linker"] / batchnr
|
||||||
print("Epoch, train loss", itn, round(losses['entity_linker'], 2),
|
print(
|
||||||
" / dev acc avg", round(dev_acc_context, 3))
|
"Epoch, train loss",
|
||||||
|
itn,
|
||||||
|
round(losses["entity_linker"], 2),
|
||||||
|
" / dev acc avg",
|
||||||
|
round(dev_acc_context, 3),
|
||||||
|
)
|
||||||
|
|
||||||
# STEP 7: measure the performance of our trained pipe on an independent dev set
|
# STEP 7: measure the performance of our trained pipe on an independent dev set
|
||||||
if len(dev_data) and measure_performance:
|
if len(dev_data) and measure_performance:
|
||||||
print()
|
print()
|
||||||
print("STEP 7: performance measurement of Entity Linking pipe", datetime.datetime.now())
|
print("STEP 7: performance measurement of Entity Linking pipe", now())
|
||||||
print()
|
print()
|
||||||
|
|
||||||
counts, acc_r, acc_r_label, acc_p, acc_p_label, acc_o, acc_o_label = _measure_baselines(dev_data, kb_2)
|
counts, acc_r, acc_r_d, acc_p, acc_p_d, acc_o, acc_o_d = _measure_baselines(
|
||||||
|
dev_data, kb_2
|
||||||
|
)
|
||||||
print("dev counts:", sorted(counts.items(), key=lambda x: x[0]))
|
print("dev counts:", sorted(counts.items(), key=lambda x: x[0]))
|
||||||
print("dev acc oracle:", round(acc_o, 3), [(x, round(y, 3)) for x, y in acc_o_label.items()])
|
|
||||||
print("dev acc random:", round(acc_r, 3), [(x, round(y, 3)) for x, y in acc_r_label.items()])
|
oracle_by_label = [(x, round(y, 3)) for x, y in acc_o_d.items()]
|
||||||
print("dev acc prior:", round(acc_p, 3), [(x, round(y, 3)) for x, y in acc_p_label.items()])
|
print("dev acc oracle:", round(acc_o, 3), oracle_by_label)
|
||||||
|
|
||||||
|
random_by_label = [(x, round(y, 3)) for x, y in acc_r_d.items()]
|
||||||
|
print("dev acc random:", round(acc_r, 3), random_by_label)
|
||||||
|
|
||||||
|
prior_by_label = [(x, round(y, 3)) for x, y in acc_p_d.items()]
|
||||||
|
print("dev acc prior:", round(acc_p, 3), prior_by_label)
|
||||||
|
|
||||||
# using only context
|
# using only context
|
||||||
el_pipe.cfg["context_weight"] = 1
|
el_pipe.cfg["context_weight"] = 1
|
||||||
el_pipe.cfg["prior_weight"] = 0
|
el_pipe.cfg["prior_weight"] = 0
|
||||||
dev_acc_context, dev_acc_context_dict = _measure_accuracy(dev_data, el_pipe)
|
dev_acc_context, dev_acc_cont_d = _measure_acc(dev_data, el_pipe)
|
||||||
print("dev acc context avg:", round(dev_acc_context, 3),
|
context_by_label = [(x, round(y, 3)) for x, y in dev_acc_cont_d.items()]
|
||||||
[(x, round(y, 3)) for x, y in dev_acc_context_dict.items()])
|
print("dev acc context avg:", round(dev_acc_context, 3), context_by_label)
|
||||||
|
|
||||||
# measuring combined accuracy (prior + context)
|
# measuring combined accuracy (prior + context)
|
||||||
el_pipe.cfg["context_weight"] = 1
|
el_pipe.cfg["context_weight"] = 1
|
||||||
el_pipe.cfg["prior_weight"] = 1
|
el_pipe.cfg["prior_weight"] = 1
|
||||||
dev_acc_combo, dev_acc_combo_dict = _measure_accuracy(dev_data, el_pipe, error_analysis=False)
|
dev_acc_combo, dev_acc_combo_d = _measure_acc(dev_data, el_pipe)
|
||||||
print("dev acc combo avg:", round(dev_acc_combo, 3),
|
combo_by_label = [(x, round(y, 3)) for x, y in dev_acc_combo_d.items()]
|
||||||
[(x, round(y, 3)) for x, y in dev_acc_combo_dict.items()])
|
print("dev acc combo avg:", round(dev_acc_combo, 3), combo_by_label)
|
||||||
|
|
||||||
# STEP 8: apply the EL pipe on a toy example
|
# STEP 8: apply the EL pipe on a toy example
|
||||||
if to_test_pipeline:
|
if to_test_pipeline:
|
||||||
print()
|
print()
|
||||||
print("STEP 8: applying Entity Linking to toy example", datetime.datetime.now())
|
print("STEP 8: applying Entity Linking to toy example", now())
|
||||||
print()
|
print()
|
||||||
run_el_toy_example(nlp=nlp_2)
|
run_el_toy_example(nlp=nlp_2)
|
||||||
|
|
||||||
# STEP 9: write the NLP pipeline (including entity linker) to file
|
# STEP 9: write the NLP pipeline (including entity linker) to file
|
||||||
if to_write_nlp:
|
if to_write_nlp:
|
||||||
print()
|
print()
|
||||||
print("STEP 9: testing NLP IO", datetime.datetime.now())
|
print("STEP 9: testing NLP IO", now())
|
||||||
print()
|
print()
|
||||||
print("writing to", NLP_2_DIR)
|
print("writing to", NLP_2_DIR)
|
||||||
nlp_2.to_disk(NLP_2_DIR)
|
nlp_2.to_disk(NLP_2_DIR)
|
||||||
|
@ -262,23 +294,26 @@ def run_pipeline():
|
||||||
el_pipe = nlp_3.get_pipe("entity_linker")
|
el_pipe = nlp_3.get_pipe("entity_linker")
|
||||||
|
|
||||||
dev_limit = 5000
|
dev_limit = 5000
|
||||||
dev_data = training_set_creator.read_training(nlp=nlp_2,
|
dev_data = training_set_creator.read_training(
|
||||||
|
nlp=nlp_2,
|
||||||
training_dir=TRAINING_DIR,
|
training_dir=TRAINING_DIR,
|
||||||
dev=True,
|
dev=True,
|
||||||
limit=dev_limit)
|
limit=dev_limit,
|
||||||
|
kb=el_pipe.kb,
|
||||||
|
)
|
||||||
|
|
||||||
print("Dev testing from file on", len(dev_data), "articles")
|
print("Dev testing from file on", len(dev_data), "articles")
|
||||||
print()
|
print()
|
||||||
|
|
||||||
dev_acc_combo, dev_acc_combo_dict = _measure_accuracy(dev_data, el_pipe=el_pipe, error_analysis=False)
|
dev_acc_combo, dev_acc_combo_dict = _measure_acc(dev_data, el_pipe)
|
||||||
print("dev acc combo avg:", round(dev_acc_combo, 3),
|
combo_by_label = [(x, round(y, 3)) for x, y in dev_acc_combo_dict.items()]
|
||||||
[(x, round(y, 3)) for x, y in dev_acc_combo_dict.items()])
|
print("dev acc combo avg:", round(dev_acc_combo, 3), combo_by_label)
|
||||||
|
|
||||||
print()
|
print()
|
||||||
print("STOP", datetime.datetime.now())
|
print("STOP", now())
|
||||||
|
|
||||||
|
|
||||||
def _measure_accuracy(data, el_pipe=None, error_analysis=False):
|
def _measure_acc(data, el_pipe=None, error_analysis=False):
|
||||||
# If the docs in the data require further processing with an entity linker, set el_pipe
|
# If the docs in the data require further processing with an entity linker, set el_pipe
|
||||||
correct_by_label = dict()
|
correct_by_label = dict()
|
||||||
incorrect_by_label = dict()
|
incorrect_by_label = dict()
|
||||||
|
@ -291,7 +326,9 @@ def _measure_accuracy(data, el_pipe=None, error_analysis=False):
|
||||||
for doc, gold in zip(docs, golds):
|
for doc, gold in zip(docs, golds):
|
||||||
try:
|
try:
|
||||||
correct_entries_per_article = dict()
|
correct_entries_per_article = dict()
|
||||||
for entity in gold.links:
|
for entity, value in gold.links.items():
|
||||||
|
# only evaluating on positive examples
|
||||||
|
if value:
|
||||||
start, end, gold_kb = entity
|
start, end, gold_kb = entity
|
||||||
correct_entries_per_article[str(start) + "-" + str(end)] = gold_kb
|
correct_entries_per_article[str(start) + "-" + str(end)] = gold_kb
|
||||||
|
|
||||||
|
@ -300,7 +337,8 @@ def _measure_accuracy(data, el_pipe=None, error_analysis=False):
|
||||||
pred_entity = ent.kb_id_
|
pred_entity = ent.kb_id_
|
||||||
start = ent.start_char
|
start = ent.start_char
|
||||||
end = ent.end_char
|
end = ent.end_char
|
||||||
gold_entity = correct_entries_per_article.get(str(start) + "-" + str(end), None)
|
offset = str(start) + "-" + str(end)
|
||||||
|
gold_entity = correct_entries_per_article.get(offset, None)
|
||||||
# the gold annotations are not complete so we can't evaluate missing annotations as 'wrong'
|
# the gold annotations are not complete so we can't evaluate missing annotations as 'wrong'
|
||||||
if gold_entity is not None:
|
if gold_entity is not None:
|
||||||
if gold_entity == pred_entity:
|
if gold_entity == pred_entity:
|
||||||
|
@ -311,7 +349,12 @@ def _measure_accuracy(data, el_pipe=None, error_analysis=False):
|
||||||
incorrect_by_label[ent_label] = incorrect + 1
|
incorrect_by_label[ent_label] = incorrect + 1
|
||||||
if error_analysis:
|
if error_analysis:
|
||||||
print(ent.text, "in", doc)
|
print(ent.text, "in", doc)
|
||||||
print("Predicted", pred_entity, "should have been", gold_entity)
|
print(
|
||||||
|
"Predicted",
|
||||||
|
pred_entity,
|
||||||
|
"should have been",
|
||||||
|
gold_entity,
|
||||||
|
)
|
||||||
print()
|
print()
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -323,16 +366,16 @@ def _measure_accuracy(data, el_pipe=None, error_analysis=False):
|
||||||
|
|
||||||
def _measure_baselines(data, kb):
|
def _measure_baselines(data, kb):
|
||||||
# Measure 3 performance baselines: random selection, prior probabilities, and 'oracle' prediction for upper bound
|
# Measure 3 performance baselines: random selection, prior probabilities, and 'oracle' prediction for upper bound
|
||||||
counts_by_label = dict()
|
counts_d = dict()
|
||||||
|
|
||||||
random_correct_by_label = dict()
|
random_correct_d = dict()
|
||||||
random_incorrect_by_label = dict()
|
random_incorrect_d = dict()
|
||||||
|
|
||||||
oracle_correct_by_label = dict()
|
oracle_correct_d = dict()
|
||||||
oracle_incorrect_by_label = dict()
|
oracle_incorrect_d = dict()
|
||||||
|
|
||||||
prior_correct_by_label = dict()
|
prior_correct_d = dict()
|
||||||
prior_incorrect_by_label = dict()
|
prior_incorrect_d = dict()
|
||||||
|
|
||||||
docs = [d for d, g in data if len(d) > 0]
|
docs = [d for d, g in data if len(d) > 0]
|
||||||
golds = [g for d, g in data if len(d) > 0]
|
golds = [g for d, g in data if len(d) > 0]
|
||||||
|
@ -345,14 +388,15 @@ def _measure_baselines(data, kb):
|
||||||
correct_entries_per_article[str(start) + "-" + str(end)] = gold_kb
|
correct_entries_per_article[str(start) + "-" + str(end)] = gold_kb
|
||||||
|
|
||||||
for ent in doc.ents:
|
for ent in doc.ents:
|
||||||
ent_label = ent.label_
|
label = ent.label_
|
||||||
start = ent.start_char
|
start = ent.start_char
|
||||||
end = ent.end_char
|
end = ent.end_char
|
||||||
gold_entity = correct_entries_per_article.get(str(start) + "-" + str(end), None)
|
offset = str(start) + "-" + str(end)
|
||||||
|
gold_entity = correct_entries_per_article.get(offset, None)
|
||||||
|
|
||||||
# the gold annotations are not complete so we can't evaluate missing annotations as 'wrong'
|
# the gold annotations are not complete so we can't evaluate missing annotations as 'wrong'
|
||||||
if gold_entity is not None:
|
if gold_entity is not None:
|
||||||
counts_by_label[ent_label] = counts_by_label.get(ent_label, 0) + 1
|
counts_d[label] = counts_d.get(label, 0) + 1
|
||||||
candidates = kb.get_candidates(ent.text)
|
candidates = kb.get_candidates(ent.text)
|
||||||
oracle_candidate = ""
|
oracle_candidate = ""
|
||||||
best_candidate = ""
|
best_candidate = ""
|
||||||
|
@ -370,28 +414,36 @@ def _measure_baselines(data, kb):
|
||||||
random_candidate = random.choice(candidates).entity_
|
random_candidate = random.choice(candidates).entity_
|
||||||
|
|
||||||
if gold_entity == best_candidate:
|
if gold_entity == best_candidate:
|
||||||
prior_correct_by_label[ent_label] = prior_correct_by_label.get(ent_label, 0) + 1
|
prior_correct_d[label] = prior_correct_d.get(label, 0) + 1
|
||||||
else:
|
else:
|
||||||
prior_incorrect_by_label[ent_label] = prior_incorrect_by_label.get(ent_label, 0) + 1
|
prior_incorrect_d[label] = prior_incorrect_d.get(label, 0) + 1
|
||||||
|
|
||||||
if gold_entity == random_candidate:
|
if gold_entity == random_candidate:
|
||||||
random_correct_by_label[ent_label] = random_correct_by_label.get(ent_label, 0) + 1
|
random_correct_d[label] = random_correct_d.get(label, 0) + 1
|
||||||
else:
|
else:
|
||||||
random_incorrect_by_label[ent_label] = random_incorrect_by_label.get(ent_label, 0) + 1
|
random_incorrect_d[label] = random_incorrect_d.get(label, 0) + 1
|
||||||
|
|
||||||
if gold_entity == oracle_candidate:
|
if gold_entity == oracle_candidate:
|
||||||
oracle_correct_by_label[ent_label] = oracle_correct_by_label.get(ent_label, 0) + 1
|
oracle_correct_d[label] = oracle_correct_d.get(label, 0) + 1
|
||||||
else:
|
else:
|
||||||
oracle_incorrect_by_label[ent_label] = oracle_incorrect_by_label.get(ent_label, 0) + 1
|
oracle_incorrect_d[label] = oracle_incorrect_d.get(label, 0) + 1
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print("Error assessing accuracy", e)
|
print("Error assessing accuracy", e)
|
||||||
|
|
||||||
acc_prior, acc_prior_by_label = calculate_acc(prior_correct_by_label, prior_incorrect_by_label)
|
acc_prior, acc_prior_d = calculate_acc(prior_correct_d, prior_incorrect_d)
|
||||||
acc_rand, acc_rand_by_label = calculate_acc(random_correct_by_label, random_incorrect_by_label)
|
acc_rand, acc_rand_d = calculate_acc(random_correct_d, random_incorrect_d)
|
||||||
acc_oracle, acc_oracle_by_label = calculate_acc(oracle_correct_by_label, oracle_incorrect_by_label)
|
acc_oracle, acc_oracle_d = calculate_acc(oracle_correct_d, oracle_incorrect_d)
|
||||||
|
|
||||||
return counts_by_label, acc_rand, acc_rand_by_label, acc_prior, acc_prior_by_label, acc_oracle, acc_oracle_by_label
|
return (
|
||||||
|
counts_d,
|
||||||
|
acc_rand,
|
||||||
|
acc_rand_d,
|
||||||
|
acc_prior,
|
||||||
|
acc_prior_d,
|
||||||
|
acc_oracle,
|
||||||
|
acc_oracle_d,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def calculate_acc(correct_by_label, incorrect_by_label):
|
def calculate_acc(correct_by_label, incorrect_by_label):
|
||||||
|
@ -422,15 +474,23 @@ def check_kb(kb):
|
||||||
|
|
||||||
print("generating candidates for " + mention + " :")
|
print("generating candidates for " + mention + " :")
|
||||||
for c in candidates:
|
for c in candidates:
|
||||||
print(" ", c.prior_prob, c.alias_, "-->", c.entity_ + " (freq=" + str(c.entity_freq) + ")")
|
print(
|
||||||
|
" ",
|
||||||
|
c.prior_prob,
|
||||||
|
c.alias_,
|
||||||
|
"-->",
|
||||||
|
c.entity_ + " (freq=" + str(c.entity_freq) + ")",
|
||||||
|
)
|
||||||
print()
|
print()
|
||||||
|
|
||||||
|
|
||||||
def run_el_toy_example(nlp):
|
def run_el_toy_example(nlp):
|
||||||
text = "In The Hitchhiker's Guide to the Galaxy, written by Douglas Adams, " \
|
text = (
|
||||||
"Douglas reminds us to always bring our towel, even in China or Brazil. " \
|
"In The Hitchhiker's Guide to the Galaxy, written by Douglas Adams, "
|
||||||
"The main character in Doug's novel is the man Arthur Dent, " \
|
"Douglas reminds us to always bring our towel, even in China or Brazil. "
|
||||||
|
"The main character in Doug's novel is the man Arthur Dent, "
|
||||||
"but Douglas doesn't write about George Washington or Homer Simpson."
|
"but Douglas doesn't write about George Washington or Homer Simpson."
|
||||||
|
)
|
||||||
doc = nlp(text)
|
doc = nlp(text)
|
||||||
print(text)
|
print(text)
|
||||||
for ent in doc.ents:
|
for ent in doc.ents:
|
||||||
|
|
|
@ -208,7 +208,7 @@ cdef class KnowledgeBase:
|
||||||
|
|
||||||
# Return an empty list if this entity is unknown in this KB
|
# Return an empty list if this entity is unknown in this KB
|
||||||
if entity_hash not in self._entry_index:
|
if entity_hash not in self._entry_index:
|
||||||
return []
|
return [0] * self.entity_vector_length
|
||||||
entry_index = self._entry_index[entity_hash]
|
entry_index = self._entry_index[entity_hash]
|
||||||
|
|
||||||
return self._vectors_table[self._entries[entry_index].vector_index]
|
return self._vectors_table[self._entries[entry_index].vector_index]
|
||||||
|
|
|
@ -1151,9 +1151,11 @@ class EntityLinker(Pipe):
|
||||||
ents_by_offset = dict()
|
ents_by_offset = dict()
|
||||||
for ent in doc.ents:
|
for ent in doc.ents:
|
||||||
ents_by_offset[str(ent.start_char) + "_" + str(ent.end_char)] = ent
|
ents_by_offset[str(ent.start_char) + "_" + str(ent.end_char)] = ent
|
||||||
for entity in gold.links:
|
for entity, value in gold.links.items():
|
||||||
start, end, gold_kb = entity
|
start, end, kb_id = entity
|
||||||
mention = doc.text[start:end]
|
mention = doc.text[start:end]
|
||||||
|
entity_encoding = self.kb.get_vector(kb_id)
|
||||||
|
prior_prob = self.kb.get_prior_prob(kb_id, mention)
|
||||||
|
|
||||||
gold_ent = ents_by_offset[str(ent.start_char) + "_" + str(ent.end_char)]
|
gold_ent = ents_by_offset[str(ent.start_char) + "_" + str(ent.end_char)]
|
||||||
assert gold_ent is not None
|
assert gold_ent is not None
|
||||||
|
@ -1161,24 +1163,17 @@ class EntityLinker(Pipe):
|
||||||
if len(type_to_int) > 0:
|
if len(type_to_int) > 0:
|
||||||
type_vector[type_to_int[gold_ent.label_]] = 1
|
type_vector[type_to_int[gold_ent.label_]] = 1
|
||||||
|
|
||||||
candidates = self.kb.get_candidates(mention)
|
# store data
|
||||||
random.shuffle(candidates)
|
|
||||||
for c in candidates:
|
|
||||||
kb_id = c.entity_
|
|
||||||
entity_encoding = c.entity_vector
|
|
||||||
entity_encodings.append(entity_encoding)
|
entity_encodings.append(entity_encoding)
|
||||||
context_docs.append(doc)
|
context_docs.append(doc)
|
||||||
type_vectors.append(type_vector)
|
type_vectors.append(type_vector)
|
||||||
|
|
||||||
if self.cfg.get("prior_weight", 1) > 0:
|
if self.cfg.get("prior_weight", 1) > 0:
|
||||||
priors.append([c.prior_prob])
|
priors.append([prior_prob])
|
||||||
else:
|
else:
|
||||||
priors.append([0])
|
priors.append([0])
|
||||||
|
|
||||||
if kb_id == gold_kb:
|
cats.append([value])
|
||||||
cats.append([1])
|
|
||||||
else:
|
|
||||||
cats.append([0])
|
|
||||||
|
|
||||||
if len(entity_encodings) > 0:
|
if len(entity_encodings) > 0:
|
||||||
assert len(priors) == len(entity_encodings) == len(context_docs) == len(cats) == len(type_vectors)
|
assert len(priors) == len(entity_encodings) == len(context_docs) == len(cats) == len(type_vectors)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user