mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-26 09:14:32 +03:00
clean up code, remove old code, move to bin
This commit is contained in:
parent
ffae7d3555
commit
0d177c1146
0
bin/wiki_entity_linking/__init__.py
Normal file
0
bin/wiki_entity_linking/__init__.py
Normal file
|
@ -1,15 +1,13 @@
|
|||
# coding: utf-8
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import spacy
|
||||
from examples.pipeline.wiki_entity_linking.train_descriptions import EntityEncoder
|
||||
from bin.wiki_entity_linking.train_descriptions import EntityEncoder
|
||||
from spacy.kb import KnowledgeBase
|
||||
|
||||
import csv
|
||||
import datetime
|
||||
|
||||
from . import wikipedia_processor as wp
|
||||
from . import wikidata_processor as wd
|
||||
from bin.wiki_entity_linking import wikidata_processor as wd, wikipedia_processor as wp
|
||||
|
||||
INPUT_DIM = 300 # dimension of pre-trained vectors
|
||||
DESC_WIDTH = 64
|
||||
|
@ -34,7 +32,7 @@ def create_kb(nlp, max_entities_per_alias, min_entity_freq, min_occ,
|
|||
|
||||
else:
|
||||
# read the mappings from file
|
||||
title_to_id = _get_entity_to_id(entity_def_output)
|
||||
title_to_id = get_entity_to_id(entity_def_output)
|
||||
id_to_descr = _get_id_to_description(entity_descr_output)
|
||||
|
||||
print()
|
||||
|
@ -56,7 +54,8 @@ def create_kb(nlp, max_entities_per_alias, min_entity_freq, min_occ,
|
|||
frequency_list.append(freq)
|
||||
filtered_title_to_id[title] = entity
|
||||
|
||||
print("Kept", len(filtered_title_to_id.keys()), "out of", len(title_to_id.keys()), "titles with filter frequency", min_entity_freq)
|
||||
print("Kept", len(filtered_title_to_id.keys()), "out of", len(title_to_id.keys()),
|
||||
"titles with filter frequency", min_entity_freq)
|
||||
|
||||
print()
|
||||
print(" * train entity encoder", datetime.datetime.now())
|
||||
|
@ -101,7 +100,7 @@ def _write_entity_files(entity_def_output, entity_descr_output, title_to_id, id_
|
|||
descr_file.write(str(qid) + "|" + descr + "\n")
|
||||
|
||||
|
||||
def _get_entity_to_id(entity_def_output):
|
||||
def get_entity_to_id(entity_def_output):
|
||||
entity_to_id = dict()
|
||||
with open(entity_def_output, 'r', encoding='utf8') as csvfile:
|
||||
csvreader = csv.reader(csvfile, delimiter='|')
|
|
@ -55,8 +55,6 @@ class EntityEncoder:
|
|||
print("Trained on", processed, "entities across", self.EPOCHS, "epochs")
|
||||
print("Final loss:", loss)
|
||||
|
||||
# self._test_encoder()
|
||||
|
||||
def _train_model(self, description_list):
|
||||
# TODO: when loss gets too low, a 'mean of empty slice' warning is thrown by numpy
|
||||
|
||||
|
@ -123,40 +121,3 @@ class EntityEncoder:
|
|||
def _get_loss(golds, scores):
|
||||
loss, gradients = get_cossim_loss(scores, golds)
|
||||
return loss, gradients
|
||||
|
||||
def _test_encoder(self):
|
||||
# Test encoder on some dummy examples
|
||||
desc_A1 = "Fictional character in The Simpsons"
|
||||
desc_A2 = "Simpsons - fictional human"
|
||||
desc_A3 = "Fictional character in The Flintstones"
|
||||
desc_A4 = "Politician from the US"
|
||||
|
||||
A1_doc_vector = np.asarray([self._get_doc_embedding(self.nlp(desc_A1))])
|
||||
A2_doc_vector = np.asarray([self._get_doc_embedding(self.nlp(desc_A2))])
|
||||
A3_doc_vector = np.asarray([self._get_doc_embedding(self.nlp(desc_A3))])
|
||||
A4_doc_vector = np.asarray([self._get_doc_embedding(self.nlp(desc_A4))])
|
||||
|
||||
loss_a1_a1, _ = get_cossim_loss(A1_doc_vector, A1_doc_vector)
|
||||
loss_a1_a2, _ = get_cossim_loss(A1_doc_vector, A2_doc_vector)
|
||||
loss_a1_a3, _ = get_cossim_loss(A1_doc_vector, A3_doc_vector)
|
||||
loss_a1_a4, _ = get_cossim_loss(A1_doc_vector, A4_doc_vector)
|
||||
|
||||
print("sim doc A1 A1", loss_a1_a1)
|
||||
print("sim doc A1 A2", loss_a1_a2)
|
||||
print("sim doc A1 A3", loss_a1_a3)
|
||||
print("sim doc A1 A4", loss_a1_a4)
|
||||
|
||||
A1_encoded = self.encoder(A1_doc_vector)
|
||||
A2_encoded = self.encoder(A2_doc_vector)
|
||||
A3_encoded = self.encoder(A3_doc_vector)
|
||||
A4_encoded = self.encoder(A4_doc_vector)
|
||||
|
||||
loss_a1_a1, _ = get_cossim_loss(A1_encoded, A1_encoded)
|
||||
loss_a1_a2, _ = get_cossim_loss(A1_encoded, A2_encoded)
|
||||
loss_a1_a3, _ = get_cossim_loss(A1_encoded, A3_encoded)
|
||||
loss_a1_a4, _ = get_cossim_loss(A1_encoded, A4_encoded)
|
||||
|
||||
print("sim encoded A1 A1", loss_a1_a1)
|
||||
print("sim encoded A1 A2", loss_a1_a2)
|
||||
print("sim encoded A1 A3", loss_a1_a3)
|
||||
print("sim encoded A1 A4", loss_a1_a4)
|
|
@ -7,7 +7,7 @@ import bz2
|
|||
import datetime
|
||||
|
||||
from spacy.gold import GoldParse
|
||||
from . import wikipedia_processor as wp, kb_creator
|
||||
from bin.wiki_entity_linking import kb_creator, wikipedia_processor as wp
|
||||
|
||||
"""
|
||||
Process Wikipedia interlinks to generate a training dataset for the EL algorithm
|
||||
|
@ -18,7 +18,7 @@ ENTITY_FILE = "gold_entities_1000000.csv" # use this file for faster processin
|
|||
|
||||
|
||||
def create_training(entity_def_input, training_output):
|
||||
wp_to_id = kb_creator._get_entity_to_id(entity_def_input)
|
||||
wp_to_id = kb_creator.get_entity_to_id(entity_def_input)
|
||||
_process_wikipedia_texts(wp_to_id, training_output, limit=None)
|
||||
|
||||
|
||||
|
@ -71,7 +71,8 @@ def _process_wikipedia_texts(wp_to_id, training_output, limit=None):
|
|||
elif clean_line == "</page>":
|
||||
if article_id:
|
||||
try:
|
||||
_process_wp_text(wp_to_id, entityfile, article_id, article_title, article_text.strip(), training_output)
|
||||
_process_wp_text(wp_to_id, entityfile, article_id, article_title, article_text.strip(),
|
||||
training_output)
|
||||
except Exception as e:
|
||||
print("Error processing article", article_id, article_title, e)
|
||||
else:
|
|
@ -13,9 +13,12 @@ def read_wikidata_entities_json(limit=None, to_print=False):
|
|||
""" Read the JSON wiki data and parse out the entities. Takes about 7u30 to parse 55M lines. """
|
||||
|
||||
lang = 'en'
|
||||
# prop_filter = {'P31': {'Q5', 'Q15632617'}} # currently defined as OR: one property suffices to be selected
|
||||
site_filter = 'enwiki'
|
||||
|
||||
# filter currently disabled to get ALL data
|
||||
prop_filter = dict()
|
||||
# prop_filter = {'P31': {'Q5', 'Q15632617'}} # currently defined as OR: one property suffices to be selected
|
||||
|
||||
title_to_id = dict()
|
||||
id_to_descr = dict()
|
||||
|
||||
|
@ -25,6 +28,7 @@ def read_wikidata_entities_json(limit=None, to_print=False):
|
|||
parse_labels = False
|
||||
parse_descriptions = True
|
||||
parse_aliases = False
|
||||
parse_claims = False
|
||||
|
||||
with bz2.open(WIKIDATA_JSON, mode='rb') as file:
|
||||
line = file.readline()
|
||||
|
@ -45,14 +49,15 @@ def read_wikidata_entities_json(limit=None, to_print=False):
|
|||
keep = True
|
||||
|
||||
claims = obj["claims"]
|
||||
# for prop, value_set in prop_filter.items():
|
||||
# claim_property = claims.get(prop, None)
|
||||
# if claim_property:
|
||||
# for cp in claim_property:
|
||||
# cp_id = cp['mainsnak'].get('datavalue', {}).get('value', {}).get('id')
|
||||
# cp_rank = cp['rank']
|
||||
# if cp_rank != "deprecated" and cp_id in value_set:
|
||||
# keep = True
|
||||
if parse_claims:
|
||||
for prop, value_set in prop_filter.items():
|
||||
claim_property = claims.get(prop, None)
|
||||
if claim_property:
|
||||
for cp in claim_property:
|
||||
cp_id = cp['mainsnak'].get('datavalue', {}).get('value', {}).get('id')
|
||||
cp_rank = cp['rank']
|
||||
if cp_rank != "deprecated" and cp_id in value_set:
|
||||
keep = True
|
||||
|
||||
if keep:
|
||||
unique_id = obj["id"]
|
||||
|
@ -64,8 +69,10 @@ def read_wikidata_entities_json(limit=None, to_print=False):
|
|||
# parsing all properties that refer to other entities
|
||||
if parse_properties:
|
||||
for prop, claim_property in claims.items():
|
||||
cp_dicts = [cp['mainsnak']['datavalue'].get('value') for cp in claim_property if cp['mainsnak'].get('datavalue')]
|
||||
cp_values = [cp_dict.get('id') for cp_dict in cp_dicts if isinstance(cp_dict, dict) if cp_dict.get('id') is not None]
|
||||
cp_dicts = [cp['mainsnak']['datavalue'].get('value') for cp in claim_property
|
||||
if cp['mainsnak'].get('datavalue')]
|
||||
cp_values = [cp_dict.get('id') for cp_dict in cp_dicts if isinstance(cp_dict, dict)
|
||||
if cp_dict.get('id') is not None]
|
||||
if cp_values:
|
||||
if to_print:
|
||||
print("prop:", prop, cp_values)
|
||||
|
@ -104,7 +111,7 @@ def read_wikidata_entities_json(limit=None, to_print=False):
|
|||
if lang_aliases:
|
||||
for item in lang_aliases:
|
||||
if to_print:
|
||||
print("alias (" + lang + "):", item["value"])
|
||||
print("alias (" + lang + "):", item["value"])
|
||||
|
||||
if to_print:
|
||||
print()
|
|
@ -26,8 +26,8 @@ wiki_namespaces = ["b", "betawikiversity", "Book", "c", "Category", "Commons",
|
|||
"mw", "n", "nost", "oldwikisource", "outreach", "outreachwiki", "otrs", "OTRSwiki",
|
||||
"Portal", "phab", "Phabricator", "Project", "q", "quality", "rev",
|
||||
"s", "spcom", "Special", "species", "Strategy", "sulutil", "svn",
|
||||
"Talk", "Template", "Template talk", "Testwiki", "ticket", "TimedText", "Toollabs", "tools", "tswiki",
|
||||
"User", "User talk", "v", "voy",
|
||||
"Talk", "Template", "Template talk", "Testwiki", "ticket", "TimedText", "Toollabs", "tools",
|
||||
"tswiki", "User", "User talk", "v", "voy",
|
||||
"w", "Wikibooks", "Wikidata", "wikiHow", "Wikinvest", "wikilivres", "Wikimedia", "Wikinews",
|
||||
"Wikipedia", "Wikipedia talk", "Wikiquote", "Wikisource", "Wikispecies", "Wikitech",
|
||||
"Wikiversity", "Wikivoyage", "wikt", "wiktionary", "wmf", "wmania", "WP"]
|
|
@ -1,136 +0,0 @@
|
|||
# coding: utf-8
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import os
|
||||
import spacy
|
||||
import datetime
|
||||
from os import listdir
|
||||
|
||||
from examples.pipeline.wiki_entity_linking import training_set_creator
|
||||
|
||||
# requires: pip install neuralcoref --no-binary neuralcoref
|
||||
# import neuralcoref
|
||||
|
||||
|
||||
def run_kb_toy_example(kb):
|
||||
for mention in ("Bush", "Douglas Adams", "Homer"):
|
||||
candidates = kb.get_candidates(mention)
|
||||
|
||||
print("generating candidates for " + mention + " :")
|
||||
for c in candidates:
|
||||
print(" ", c.prior_prob, c.alias_, "-->", c.entity_ + " (freq=" + str(c.entity_freq) + ")")
|
||||
print()
|
||||
|
||||
|
||||
|
||||
|
||||
def run_el_dev(nlp, kb, training_dir, limit=None):
|
||||
correct_entries_per_article, _ = training_set_creator.read_training_entities(training_output=training_dir)
|
||||
|
||||
predictions = list()
|
||||
golds = list()
|
||||
|
||||
cnt = 0
|
||||
for f in listdir(training_dir):
|
||||
if not limit or cnt < limit:
|
||||
if is_dev(f):
|
||||
article_id = f.replace(".txt", "")
|
||||
if cnt % 500 == 0:
|
||||
print(datetime.datetime.now(), "processed", cnt, "files in the dev dataset")
|
||||
cnt += 1
|
||||
with open(os.path.join(training_dir, f), mode="r", encoding='utf8') as file:
|
||||
text = file.read()
|
||||
doc = nlp(text)
|
||||
for ent in doc.ents:
|
||||
if ent.label_ == "PERSON": # TODO: expand to other types
|
||||
gold_entity = correct_entries_per_article[article_id].get(ent.text, None)
|
||||
# only evaluating gold entities we know, because the training data is not complete
|
||||
if gold_entity:
|
||||
predictions.append(ent.kb_id_)
|
||||
golds.append(gold_entity)
|
||||
|
||||
print("Processed", cnt, "dev articles")
|
||||
print()
|
||||
evaluate(predictions, golds)
|
||||
|
||||
|
||||
def is_dev(file_name):
|
||||
return file_name.endswith("3.txt")
|
||||
|
||||
|
||||
def evaluate(predictions, golds, to_print=True, times_hundred=True):
|
||||
if len(predictions) != len(golds):
|
||||
raise ValueError("predictions and gold entities should have the same length")
|
||||
|
||||
tp = 0
|
||||
fp = 0
|
||||
fn = 0
|
||||
|
||||
corrects = 0
|
||||
incorrects = 0
|
||||
|
||||
for pred, gold in zip(predictions, golds):
|
||||
is_correct = pred == gold
|
||||
if is_correct:
|
||||
corrects += 1
|
||||
else:
|
||||
incorrects += 1
|
||||
if not pred:
|
||||
if not is_correct: # we don't care about tn
|
||||
fn += 1
|
||||
elif is_correct:
|
||||
tp += 1
|
||||
else:
|
||||
fp += 1
|
||||
|
||||
if to_print:
|
||||
print("Evaluating", len(golds), "entities")
|
||||
print("tp", tp)
|
||||
print("fp", fp)
|
||||
print("fn", fn)
|
||||
|
||||
precision = tp / (tp + fp + 0.0000001)
|
||||
recall = tp / (tp + fn + 0.0000001)
|
||||
if times_hundred:
|
||||
precision = precision*100
|
||||
recall = recall*100
|
||||
fscore = 2 * recall * precision / (recall + precision + 0.0000001)
|
||||
|
||||
accuracy = corrects / (corrects + incorrects)
|
||||
|
||||
if to_print:
|
||||
print("precision", round(precision, 1), "%")
|
||||
print("recall", round(recall, 1), "%")
|
||||
print("Fscore", round(fscore, 1), "%")
|
||||
print("Accuracy", round(accuracy, 1), "%")
|
||||
|
||||
return precision, recall, fscore, accuracy
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# TODO
|
||||
def add_coref(nlp):
|
||||
""" Add coreference resolution to our model """
|
||||
# TODO: this doesn't work yet
|
||||
# neuralcoref.add_to_pipe(nlp)
|
||||
print("done adding to pipe")
|
||||
|
||||
doc = nlp(u'My sister has a dog. She loves him.')
|
||||
print("done doc")
|
||||
|
||||
print(doc._.has_coref)
|
||||
print(doc._.coref_clusters)
|
||||
|
||||
|
||||
# TODO
|
||||
def _run_ner_depr(nlp, clean_text, article_dict):
|
||||
doc = nlp(clean_text)
|
||||
for ent in doc.ents:
|
||||
if ent.label_ == "PERSON": # TODO: expand to non-persons
|
||||
ent_id = article_dict.get(ent.text)
|
||||
if ent_id:
|
||||
print(" -", ent.text, ent.label_, ent_id)
|
||||
else:
|
||||
print(" -", ent.text, ent.label_, '???') # TODO: investigate these cases
|
|
@ -1,490 +0,0 @@
|
|||
# coding: utf-8
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import os
|
||||
import datetime
|
||||
from os import listdir
|
||||
import numpy as np
|
||||
import random
|
||||
from random import shuffle
|
||||
from thinc.neural._classes.convolution import ExtractWindow
|
||||
from thinc.neural.util import get_array_module
|
||||
|
||||
from examples.pipeline.wiki_entity_linking import run_el, training_set_creator, kb_creator
|
||||
|
||||
from spacy._ml import SpacyVectors, create_default_optimizer, zero_init, cosine
|
||||
|
||||
from thinc.api import chain, concatenate, flatten_add_lengths, clone, with_flatten
|
||||
from thinc.v2v import Model, Maxout, Affine
|
||||
from thinc.t2v import Pooling, mean_pool
|
||||
from thinc.t2t import ParametricAttention
|
||||
from thinc.misc import Residual
|
||||
from thinc.misc import LayerNorm as LN
|
||||
|
||||
# from spacy.cli.pretrain import get_cossim_loss
|
||||
from spacy.matcher import PhraseMatcher
|
||||
|
||||
|
||||
class EL_Model:
|
||||
|
||||
PRINT_INSPECT = False
|
||||
PRINT_BATCH_LOSS = False
|
||||
EPS = 0.0000000005
|
||||
|
||||
BATCH_SIZE = 100
|
||||
|
||||
DOC_CUTOFF = 300 # number of characters from the doc context
|
||||
INPUT_DIM = 300 # dimension of pre-trained vectors
|
||||
|
||||
HIDDEN_1_WIDTH = 32
|
||||
DESC_WIDTH = 64
|
||||
ARTICLE_WIDTH = 128
|
||||
SENT_WIDTH = 64
|
||||
|
||||
DROP = 0.4
|
||||
LEARN_RATE = 0.005
|
||||
EPOCHS = 10
|
||||
L2 = 1e-6
|
||||
|
||||
name = "entity_linker"
|
||||
|
||||
def __init__(self, kb, nlp):
|
||||
run_el._prepare_pipeline(nlp, kb)
|
||||
self.nlp = nlp
|
||||
self.kb = kb
|
||||
|
||||
self._build_cnn(embed_width=self.INPUT_DIM,
|
||||
desc_width=self.DESC_WIDTH,
|
||||
article_width=self.ARTICLE_WIDTH,
|
||||
sent_width=self.SENT_WIDTH,
|
||||
hidden_1_width=self.HIDDEN_1_WIDTH)
|
||||
|
||||
def train_model(self, training_dir, entity_descr_output, trainlimit=None, devlimit=None, to_print=True):
|
||||
np.seterr(divide="raise", over="warn", under="ignore", invalid="raise")
|
||||
|
||||
id_to_descr = kb_creator._get_id_to_description(entity_descr_output)
|
||||
|
||||
train_ent, train_gold, train_desc, train_art, train_art_texts, train_sent, train_sent_texts = \
|
||||
self._get_training_data(training_dir, id_to_descr, False, trainlimit, to_print=False)
|
||||
train_clusters = list(train_ent.keys())
|
||||
|
||||
dev_ent, dev_gold, dev_desc, dev_art, dev_art_texts, dev_sent, dev_sent_texts = \
|
||||
self._get_training_data(training_dir, id_to_descr, True, devlimit, to_print=False)
|
||||
dev_clusters = list(dev_ent.keys())
|
||||
|
||||
dev_pos_count = len([g for g in dev_gold.values() if g])
|
||||
dev_neg_count = len([g for g in dev_gold.values() if not g])
|
||||
|
||||
# inspect data
|
||||
if self.PRINT_INSPECT:
|
||||
for cluster, entities in train_ent.items():
|
||||
print()
|
||||
for entity in entities:
|
||||
print("entity", entity)
|
||||
print("gold", train_gold[entity])
|
||||
print("desc", train_desc[entity])
|
||||
print("sentence ID", train_sent[entity])
|
||||
print("sentence text", train_sent_texts[train_sent[entity]])
|
||||
print("article ID", train_art[entity])
|
||||
print("article text", train_art_texts[train_art[entity]])
|
||||
print()
|
||||
|
||||
train_pos_entities = [k for k, v in train_gold.items() if v]
|
||||
train_neg_entities = [k for k, v in train_gold.items() if not v]
|
||||
|
||||
train_pos_count = len(train_pos_entities)
|
||||
train_neg_count = len(train_neg_entities)
|
||||
|
||||
self._begin_training()
|
||||
|
||||
if to_print:
|
||||
print()
|
||||
print("Training on", len(train_clusters), "entity clusters in", len(train_art_texts), "articles")
|
||||
print("Training instances pos/neg:", train_pos_count, train_neg_count)
|
||||
print()
|
||||
print("Dev test on", len(dev_clusters), "entity clusters in", len(dev_art_texts), "articles")
|
||||
print("Dev instances pos/neg:", dev_pos_count, dev_neg_count)
|
||||
print()
|
||||
print(" DOC_CUTOFF", self.DOC_CUTOFF)
|
||||
print(" INPUT_DIM", self.INPUT_DIM)
|
||||
print(" HIDDEN_1_WIDTH", self.HIDDEN_1_WIDTH)
|
||||
print(" DESC_WIDTH", self.DESC_WIDTH)
|
||||
print(" ARTICLE_WIDTH", self.ARTICLE_WIDTH)
|
||||
print(" SENT_WIDTH", self.SENT_WIDTH)
|
||||
print(" DROP", self.DROP)
|
||||
print(" LEARNING RATE", self.LEARN_RATE)
|
||||
print(" BATCH SIZE", self.BATCH_SIZE)
|
||||
print()
|
||||
|
||||
dev_random = self._test_dev(dev_ent, dev_gold, dev_desc, dev_art, dev_art_texts, dev_sent, dev_sent_texts,
|
||||
calc_random=True)
|
||||
print("acc", "dev_random", round(dev_random, 2))
|
||||
|
||||
dev_pre = self._test_dev(dev_ent, dev_gold, dev_desc, dev_art, dev_art_texts, dev_sent, dev_sent_texts,
|
||||
avg=True)
|
||||
print("acc", "dev_pre", round(dev_pre, 2))
|
||||
print()
|
||||
|
||||
processed = 0
|
||||
for i in range(self.EPOCHS):
|
||||
shuffle(train_clusters)
|
||||
|
||||
start = 0
|
||||
stop = min(self.BATCH_SIZE, len(train_clusters))
|
||||
|
||||
while start < len(train_clusters):
|
||||
next_batch = {c: train_ent[c] for c in train_clusters[start:stop]}
|
||||
processed += len(next_batch.keys())
|
||||
|
||||
self.update(entity_clusters=next_batch, golds=train_gold, descs=train_desc,
|
||||
art_texts=train_art_texts, arts=train_art,
|
||||
sent_texts=train_sent_texts, sents=train_sent)
|
||||
|
||||
start = start + self.BATCH_SIZE
|
||||
stop = min(stop + self.BATCH_SIZE, len(train_clusters))
|
||||
|
||||
train_acc = self._test_dev(train_ent, train_gold, train_desc, train_art, train_art_texts, train_sent, train_sent_texts, avg=True)
|
||||
dev_acc = self._test_dev(dev_ent, dev_gold, dev_desc, dev_art, dev_art_texts, dev_sent, dev_sent_texts, avg=True)
|
||||
|
||||
print(i, "acc train/dev", round(train_acc, 2), round(dev_acc, 2))
|
||||
|
||||
if to_print:
|
||||
print()
|
||||
print("Trained on", processed, "entity clusters across", self.EPOCHS, "epochs")
|
||||
|
||||
def _test_dev(self, entity_clusters, golds, descs, arts, art_texts, sents, sent_texts, avg=True, calc_random=False):
|
||||
correct = 0
|
||||
incorrect = 0
|
||||
|
||||
if calc_random:
|
||||
for cluster, entities in entity_clusters.items():
|
||||
correct_entities = [e for e in entities if golds[e]]
|
||||
assert len(correct_entities) == 1
|
||||
|
||||
entities = list(entities)
|
||||
shuffle(entities)
|
||||
|
||||
if calc_random:
|
||||
predicted_entity = random.choice(entities)
|
||||
if predicted_entity in correct_entities:
|
||||
correct += 1
|
||||
else:
|
||||
incorrect += 1
|
||||
|
||||
else:
|
||||
all_clusters = list()
|
||||
arts_list = list()
|
||||
sents_list = list()
|
||||
|
||||
for cluster in entity_clusters.keys():
|
||||
all_clusters.append(cluster)
|
||||
arts_list.append(art_texts[arts[cluster]])
|
||||
sents_list.append(sent_texts[sents[cluster]])
|
||||
|
||||
art_docs = list(self.nlp.pipe(arts_list))
|
||||
sent_docs = list(self.nlp.pipe(sents_list))
|
||||
|
||||
for i, cluster in enumerate(all_clusters):
|
||||
entities = entity_clusters[cluster]
|
||||
correct_entities = [e for e in entities if golds[e]]
|
||||
assert len(correct_entities) == 1
|
||||
|
||||
entities = list(entities)
|
||||
shuffle(entities)
|
||||
|
||||
desc_docs = self.nlp.pipe([descs[e] for e in entities])
|
||||
sent_doc = sent_docs[i]
|
||||
article_doc = art_docs[i]
|
||||
|
||||
predicted_index = self._predict(article_doc=article_doc, sent_doc=sent_doc,
|
||||
desc_docs=desc_docs, avg=avg)
|
||||
if entities[predicted_index] in correct_entities:
|
||||
correct += 1
|
||||
else:
|
||||
incorrect += 1
|
||||
|
||||
if correct == incorrect == 0:
|
||||
return 0
|
||||
|
||||
acc = correct / (correct + incorrect)
|
||||
return acc
|
||||
|
||||
def _predict(self, article_doc, sent_doc, desc_docs, avg=True, apply_threshold=True):
|
||||
if avg:
|
||||
with self.article_encoder.use_params(self.sgd_article.averages) \
|
||||
and self.desc_encoder.use_params(self.sgd_desc.averages)\
|
||||
and self.sent_encoder.use_params(self.sgd_sent.averages):
|
||||
desc_encodings = self.desc_encoder(desc_docs)
|
||||
doc_encoding = self.article_encoder([article_doc])
|
||||
sent_encoding = self.sent_encoder([sent_doc])
|
||||
|
||||
else:
|
||||
desc_encodings = self.desc_encoder(desc_docs)
|
||||
doc_encoding = self.article_encoder([article_doc])
|
||||
sent_encoding = self.sent_encoder([sent_doc])
|
||||
|
||||
concat_encoding = [list(doc_encoding[0]) + list(sent_encoding[0])]
|
||||
|
||||
if avg:
|
||||
with self.cont_encoder.use_params(self.sgd_cont.averages):
|
||||
cont_encodings = self.cont_encoder(np.asarray([concat_encoding[0]]))
|
||||
|
||||
else:
|
||||
cont_encodings = self.cont_encoder(np.asarray([concat_encoding[0]]))
|
||||
|
||||
context_enc = np.transpose(cont_encodings)
|
||||
|
||||
highest_sim = -5
|
||||
best_i = -1
|
||||
for i, desc_enc in enumerate(desc_encodings):
|
||||
sim = cosine(desc_enc, context_enc)
|
||||
if sim >= highest_sim:
|
||||
best_i = i
|
||||
highest_sim = sim
|
||||
|
||||
return best_i
|
||||
|
||||
def _build_cnn(self, embed_width, desc_width, article_width, sent_width, hidden_1_width):
|
||||
self.desc_encoder = self._encoder(in_width=embed_width, hidden_with=hidden_1_width, end_width=desc_width)
|
||||
self.cont_encoder = self._context_encoder(embed_width=embed_width, article_width=article_width,
|
||||
sent_width=sent_width, hidden_width=hidden_1_width,
|
||||
end_width=desc_width)
|
||||
|
||||
|
||||
# def _encoder(self, width):
|
||||
# tok2vec = Tok2Vec(width=width, embed_size=2000, pretrained_vectors=self.nlp.vocab.vectors.name, cnn_maxout_pieces=3,
|
||||
# subword_features=False, conv_depth=4, bilstm_depth=0)
|
||||
#
|
||||
# return tok2vec >> flatten_add_lengths >> Pooling(mean_pool)
|
||||
|
||||
def _context_encoder(self, embed_width, article_width, sent_width, hidden_width, end_width):
|
||||
self.article_encoder = self._encoder(in_width=embed_width, hidden_with=hidden_width, end_width=article_width)
|
||||
self.sent_encoder = self._encoder(in_width=embed_width, hidden_with=hidden_width, end_width=sent_width)
|
||||
|
||||
model = Affine(end_width, article_width+sent_width, drop_factor=0.0)
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
def _encoder(in_width, hidden_with, end_width):
|
||||
conv_depth = 2
|
||||
cnn_maxout_pieces = 3
|
||||
|
||||
with Model.define_operators({">>": chain, "**": clone}):
|
||||
convolution = Residual((ExtractWindow(nW=1) >>
|
||||
LN(Maxout(hidden_with, hidden_with * 3, pieces=cnn_maxout_pieces))))
|
||||
|
||||
encoder = SpacyVectors \
|
||||
>> with_flatten(LN(Maxout(hidden_with, in_width)) >> convolution ** conv_depth, pad=conv_depth) \
|
||||
>> flatten_add_lengths \
|
||||
>> ParametricAttention(hidden_with)\
|
||||
>> Pooling(mean_pool) \
|
||||
>> Residual(zero_init(Maxout(hidden_with, hidden_with))) \
|
||||
>> zero_init(Affine(end_width, hidden_with, drop_factor=0.0))
|
||||
|
||||
# TODO: ReLu or LN(Maxout) ?
|
||||
# sum_pool or mean_pool ?
|
||||
|
||||
return encoder
|
||||
|
||||
def _begin_training(self):
|
||||
self.sgd_article = create_default_optimizer(self.article_encoder.ops)
|
||||
self.sgd_article.learn_rate = self.LEARN_RATE
|
||||
self.sgd_article.L2 = self.L2
|
||||
|
||||
self.sgd_sent = create_default_optimizer(self.sent_encoder.ops)
|
||||
self.sgd_sent.learn_rate = self.LEARN_RATE
|
||||
self.sgd_sent.L2 = self.L2
|
||||
|
||||
self.sgd_cont = create_default_optimizer(self.cont_encoder.ops)
|
||||
self.sgd_cont.learn_rate = self.LEARN_RATE
|
||||
self.sgd_cont.L2 = self.L2
|
||||
|
||||
self.sgd_desc = create_default_optimizer(self.desc_encoder.ops)
|
||||
self.sgd_desc.learn_rate = self.LEARN_RATE
|
||||
self.sgd_desc.L2 = self.L2
|
||||
|
||||
def get_loss(self, pred, gold, targets):
|
||||
loss, gradients = self.get_cossim_loss(pred, gold, targets)
|
||||
return loss, gradients
|
||||
|
||||
def get_cossim_loss(self, yh, y, t):
|
||||
# Add a small constant to avoid 0 vectors
|
||||
# print()
|
||||
# print("yh", yh)
|
||||
# print("y", y)
|
||||
# print("t", t)
|
||||
yh = yh + 1e-8
|
||||
y = y + 1e-8
|
||||
# https://math.stackexchange.com/questions/1923613/partial-derivative-of-cosine-similarity
|
||||
xp = get_array_module(yh)
|
||||
norm_yh = xp.linalg.norm(yh, axis=1, keepdims=True)
|
||||
norm_y = xp.linalg.norm(y, axis=1, keepdims=True)
|
||||
mul_norms = norm_yh * norm_y
|
||||
cos = (yh * y).sum(axis=1, keepdims=True) / mul_norms
|
||||
# print("cos", cos)
|
||||
d_yh = (y / mul_norms) - (cos * (yh / norm_yh ** 2))
|
||||
# print("abs", xp.abs(cos - t))
|
||||
loss = xp.abs(cos - t).sum()
|
||||
# print("loss", loss)
|
||||
# print("d_yh", d_yh)
|
||||
inverse = np.asarray([int(t[i][0]) * d_yh[i] for i in range(len(t))])
|
||||
# print("inverse", inverse)
|
||||
return loss, -inverse
|
||||
|
||||
def update(self, entity_clusters, golds, descs, art_texts, arts, sent_texts, sents):
|
||||
arts_list = list()
|
||||
sents_list = list()
|
||||
descs_list = list()
|
||||
targets = list()
|
||||
|
||||
for cluster, entities in entity_clusters.items():
|
||||
art = art_texts[arts[cluster]]
|
||||
sent = sent_texts[sents[cluster]]
|
||||
for e in entities:
|
||||
if golds[e]:
|
||||
arts_list.append(art)
|
||||
sents_list.append(sent)
|
||||
descs_list.append(descs[e])
|
||||
targets.append([1])
|
||||
# else:
|
||||
# arts_list.append(art)
|
||||
# sents_list.append(sent)
|
||||
# descs_list.append(descs[e])
|
||||
# targets.append([-1])
|
||||
|
||||
desc_docs = self.nlp.pipe(descs_list)
|
||||
desc_encodings, bp_desc = self.desc_encoder.begin_update(desc_docs, drop=self.DROP)
|
||||
|
||||
art_docs = self.nlp.pipe(arts_list)
|
||||
sent_docs = self.nlp.pipe(sents_list)
|
||||
|
||||
doc_encodings, bp_doc = self.article_encoder.begin_update(art_docs, drop=self.DROP)
|
||||
sent_encodings, bp_sent = self.sent_encoder.begin_update(sent_docs, drop=self.DROP)
|
||||
|
||||
concat_encodings = [list(doc_encodings[i]) + list(sent_encodings[i]) for i in
|
||||
range(len(targets))]
|
||||
cont_encodings, bp_cont = self.cont_encoder.begin_update(np.asarray(concat_encodings), drop=self.DROP)
|
||||
|
||||
loss, cont_gradient = self.get_loss(cont_encodings, desc_encodings, targets)
|
||||
|
||||
# loss, desc_gradient = self.get_loss(desc_encodings, cont_encodings, targets)
|
||||
# cont_gradient = cont_gradient / 2
|
||||
# desc_gradient = desc_gradient / 2
|
||||
# bp_desc(desc_gradient, sgd=self.sgd_desc)
|
||||
|
||||
if self.PRINT_BATCH_LOSS:
|
||||
print("batch loss", loss)
|
||||
|
||||
context_gradient = bp_cont(cont_gradient, sgd=self.sgd_cont)
|
||||
|
||||
# gradient : concat (doc+sent) vs. desc
|
||||
sent_start = self.ARTICLE_WIDTH
|
||||
sent_gradients = list()
|
||||
doc_gradients = list()
|
||||
for x in context_gradient:
|
||||
doc_gradients.append(list(x[0:sent_start]))
|
||||
sent_gradients.append(list(x[sent_start:]))
|
||||
|
||||
bp_doc(doc_gradients, sgd=self.sgd_article)
|
||||
bp_sent(sent_gradients, sgd=self.sgd_sent)
|
||||
|
||||
def _get_training_data(self, training_dir, id_to_descr, dev, limit, to_print):
|
||||
correct_entries, incorrect_entries = training_set_creator.read_training_entities(training_output=training_dir)
|
||||
|
||||
entities_by_cluster = dict()
|
||||
gold_by_entity = dict()
|
||||
desc_by_entity = dict()
|
||||
article_by_cluster = dict()
|
||||
text_by_article = dict()
|
||||
sentence_by_cluster = dict()
|
||||
text_by_sentence = dict()
|
||||
sentence_by_text = dict()
|
||||
|
||||
cnt = 0
|
||||
next_entity_nr = 1
|
||||
next_sent_nr = 1
|
||||
files = listdir(training_dir)
|
||||
shuffle(files)
|
||||
for f in files:
|
||||
if not limit or cnt < limit:
|
||||
if dev == run_el.is_dev(f):
|
||||
article_id = f.replace(".txt", "")
|
||||
if cnt % 500 == 0 and to_print:
|
||||
print(datetime.datetime.now(), "processed", cnt, "files in the training dataset")
|
||||
|
||||
try:
|
||||
# parse the article text
|
||||
with open(os.path.join(training_dir, f), mode="r", encoding='utf8') as file:
|
||||
text = file.read()
|
||||
article_doc = self.nlp(text)
|
||||
truncated_text = text[0:min(self.DOC_CUTOFF, len(text))]
|
||||
text_by_article[article_id] = truncated_text
|
||||
|
||||
# process all positive and negative entities, collect all relevant mentions in this article
|
||||
for mention, entity_pos in correct_entries[article_id].items():
|
||||
cluster = article_id + "_" + mention
|
||||
descr = id_to_descr.get(entity_pos)
|
||||
entities = set()
|
||||
if descr:
|
||||
entity = "E_" + str(next_entity_nr) + "_" + cluster
|
||||
next_entity_nr += 1
|
||||
gold_by_entity[entity] = 1
|
||||
desc_by_entity[entity] = descr
|
||||
entities.add(entity)
|
||||
|
||||
entity_negs = incorrect_entries[article_id][mention]
|
||||
for entity_neg in entity_negs:
|
||||
descr = id_to_descr.get(entity_neg)
|
||||
if descr:
|
||||
entity = "E_" + str(next_entity_nr) + "_" + cluster
|
||||
next_entity_nr += 1
|
||||
gold_by_entity[entity] = 0
|
||||
desc_by_entity[entity] = descr
|
||||
entities.add(entity)
|
||||
|
||||
found_matches = 0
|
||||
if len(entities) > 1:
|
||||
entities_by_cluster[cluster] = entities
|
||||
|
||||
# find all matches in the doc for the mentions
|
||||
# TODO: fix this - doesn't look like all entities are found
|
||||
matcher = PhraseMatcher(self.nlp.vocab)
|
||||
patterns = list(self.nlp.tokenizer.pipe([mention]))
|
||||
|
||||
matcher.add("TerminologyList", None, *patterns)
|
||||
matches = matcher(article_doc)
|
||||
|
||||
# store sentences
|
||||
for match_id, start, end in matches:
|
||||
span = article_doc[start:end]
|
||||
if mention == span.text:
|
||||
found_matches += 1
|
||||
sent_text = span.sent.text
|
||||
sent_nr = sentence_by_text.get(sent_text, None)
|
||||
if sent_nr is None:
|
||||
sent_nr = "S_" + str(next_sent_nr) + article_id
|
||||
next_sent_nr += 1
|
||||
text_by_sentence[sent_nr] = sent_text
|
||||
sentence_by_text[sent_text] = sent_nr
|
||||
article_by_cluster[cluster] = article_id
|
||||
sentence_by_cluster[cluster] = sent_nr
|
||||
|
||||
if found_matches == 0:
|
||||
# print("Could not find neg instances or sentence matches for", mention, "in", article_id)
|
||||
entities_by_cluster.pop(cluster, None)
|
||||
article_by_cluster.pop(cluster, None)
|
||||
sentence_by_cluster.pop(cluster, None)
|
||||
for entity in entities:
|
||||
gold_by_entity.pop(entity, None)
|
||||
desc_by_entity.pop(entity, None)
|
||||
cnt += 1
|
||||
except:
|
||||
print("Problem parsing article", article_id)
|
||||
|
||||
if to_print:
|
||||
print()
|
||||
print("Processed", cnt, "training articles, dev=" + str(dev))
|
||||
print()
|
||||
return entities_by_cluster, gold_by_entity, desc_by_entity, article_by_cluster, text_by_article, \
|
||||
sentence_by_cluster, text_by_sentence
|
||||
|
|
@ -5,8 +5,8 @@ import random
|
|||
|
||||
from spacy.util import minibatch, compounding
|
||||
|
||||
from examples.pipeline.wiki_entity_linking import wikipedia_processor as wp, kb_creator, training_set_creator, run_el
|
||||
from examples.pipeline.wiki_entity_linking.kb_creator import DESC_WIDTH
|
||||
from bin.wiki_entity_linking import training_set_creator, kb_creator, wikipedia_processor as wp
|
||||
from bin.wiki_entity_linking.kb_creator import DESC_WIDTH
|
||||
|
||||
import spacy
|
||||
from spacy.kb import KnowledgeBase
|
||||
|
@ -30,9 +30,11 @@ TRAINING_DIR = 'C:/Users/Sofie/Documents/data/wikipedia/training_data_nel/'
|
|||
MAX_CANDIDATES = 10
|
||||
MIN_ENTITY_FREQ = 20
|
||||
MIN_PAIR_OCC = 5
|
||||
DOC_SENT_CUTOFF = 2
|
||||
|
||||
EPOCHS = 10
|
||||
DROPOUT = 0.1
|
||||
LEARN_RATE = 0.005
|
||||
L2 = 1e-6
|
||||
|
||||
|
||||
def run_pipeline():
|
||||
|
@ -40,7 +42,6 @@ def run_pipeline():
|
|||
print()
|
||||
nlp_1 = spacy.load('en_core_web_lg')
|
||||
nlp_2 = None
|
||||
kb_1 = None
|
||||
kb_2 = None
|
||||
|
||||
# one-time methods to create KB and write to file
|
||||
|
@ -114,7 +115,7 @@ def run_pipeline():
|
|||
|
||||
# test KB
|
||||
if to_test_kb:
|
||||
test_kb(kb_2)
|
||||
check_kb(kb_2)
|
||||
print()
|
||||
|
||||
# STEP 5: create a training dataset from WP
|
||||
|
@ -122,19 +123,21 @@ def run_pipeline():
|
|||
print("STEP 5: create training dataset", datetime.datetime.now())
|
||||
training_set_creator.create_training(entity_def_input=ENTITY_DEFS, training_output=TRAINING_DIR)
|
||||
|
||||
# STEP 6: create the entity linking pipe
|
||||
el_pipe = nlp_2.create_pipe(name='entity_linker', config={"doc_cutoff": DOC_SENT_CUTOFF})
|
||||
# STEP 6: create and train the entity linking pipe
|
||||
el_pipe = nlp_2.create_pipe(name='entity_linker', config={})
|
||||
el_pipe.set_kb(kb_2)
|
||||
nlp_2.add_pipe(el_pipe, last=True)
|
||||
|
||||
other_pipes = [pipe for pipe in nlp_2.pipe_names if pipe != "entity_linker"]
|
||||
with nlp_2.disable_pipes(*other_pipes): # only train Entity Linking
|
||||
nlp_2.begin_training()
|
||||
optimizer = nlp_2.begin_training()
|
||||
optimizer.learn_rate = LEARN_RATE
|
||||
optimizer.L2 = L2
|
||||
|
||||
if train_pipe:
|
||||
print("STEP 6: training Entity Linking pipe", datetime.datetime.now())
|
||||
train_limit = 25000
|
||||
dev_limit = 1000
|
||||
dev_limit = 5000
|
||||
|
||||
train_data = training_set_creator.read_training(nlp=nlp_2,
|
||||
training_dir=TRAINING_DIR,
|
||||
|
@ -144,6 +147,14 @@ def run_pipeline():
|
|||
print("Training on", len(train_data), "articles")
|
||||
print()
|
||||
|
||||
dev_data = training_set_creator.read_training(nlp=nlp_2,
|
||||
training_dir=TRAINING_DIR,
|
||||
dev=True,
|
||||
limit=dev_limit)
|
||||
|
||||
print("Dev testing on", len(dev_data), "articles")
|
||||
print()
|
||||
|
||||
if not train_data:
|
||||
print("Did not find any training data")
|
||||
|
||||
|
@ -161,53 +172,55 @@ def run_pipeline():
|
|||
nlp_2.update(
|
||||
docs,
|
||||
golds,
|
||||
sgd=optimizer,
|
||||
drop=DROPOUT,
|
||||
losses=losses,
|
||||
)
|
||||
batchnr += 1
|
||||
except Exception as e:
|
||||
print("Error updating batch:", e)
|
||||
raise(e)
|
||||
|
||||
if batchnr > 0:
|
||||
losses['entity_linker'] = losses['entity_linker'] / batchnr
|
||||
print("Epoch, train loss", itn, round(losses['entity_linker'], 2))
|
||||
|
||||
dev_data = training_set_creator.read_training(nlp=nlp_2,
|
||||
training_dir=TRAINING_DIR,
|
||||
dev=True,
|
||||
limit=dev_limit)
|
||||
|
||||
print()
|
||||
print("Dev testing on", len(dev_data), "articles")
|
||||
with el_pipe.model.use_params(optimizer.averages):
|
||||
el_pipe.context_weight = 1
|
||||
el_pipe.prior_weight = 0
|
||||
dev_acc_context, dev_acc_context_dict = _measure_accuracy(dev_data, el_pipe)
|
||||
losses['entity_linker'] = losses['entity_linker'] / batchnr
|
||||
print("Epoch, train loss", itn, round(losses['entity_linker'], 2),
|
||||
" / dev acc context avg", round(dev_acc_context, 3))
|
||||
|
||||
# STEP 7: measure the performance of our trained pipe on an independent dev set
|
||||
if len(dev_data) and measure_performance:
|
||||
print()
|
||||
print("STEP 7: performance measurement of Entity Linking pipe", datetime.datetime.now())
|
||||
print()
|
||||
|
||||
acc_random, acc_random_by_label, acc_prior, acc_prior_by_label, acc_oracle, acc_oracle_by_label = _measure_baselines(dev_data, kb_2)
|
||||
print("dev acc oracle:", round(acc_oracle, 3), [(x, round(y, 3)) for x, y in acc_oracle_by_label.items()])
|
||||
print("dev acc random:", round(acc_random, 3), [(x, round(y, 3)) for x, y in acc_random_by_label.items()])
|
||||
print("dev acc prior:", round(acc_prior, 3), [(x, round(y, 3)) for x, y in acc_prior_by_label.items()])
|
||||
acc_r, acc_r_by_label, acc_p, acc_p_by_label, acc_o, acc_o_by_label = _measure_baselines(dev_data, kb_2)
|
||||
print("dev acc oracle:", round(acc_o, 3), [(x, round(y, 3)) for x, y in acc_o_by_label.items()])
|
||||
print("dev acc random:", round(acc_r, 3), [(x, round(y, 3)) for x, y in acc_r_by_label.items()])
|
||||
print("dev acc prior:", round(acc_p, 3), [(x, round(y, 3)) for x, y in acc_p_by_label.items()])
|
||||
|
||||
# print(" measuring accuracy 1-1")
|
||||
el_pipe.context_weight = 1
|
||||
el_pipe.prior_weight = 1
|
||||
dev_acc_combo, dev_acc_combo_dict = _measure_accuracy(dev_data, el_pipe)
|
||||
print("dev acc combo:", round(dev_acc_combo, 3), [(x, round(y, 3)) for x, y in dev_acc_combo_dict.items()])
|
||||
with el_pipe.model.use_params(optimizer.averages):
|
||||
# measuring combined accuracy (prior + context)
|
||||
el_pipe.context_weight = 1
|
||||
el_pipe.prior_weight = 1
|
||||
dev_acc_combo, dev_acc_combo_dict = _measure_accuracy(dev_data, el_pipe)
|
||||
print("dev acc combo avg:", round(dev_acc_combo, 3),
|
||||
[(x, round(y, 3)) for x, y in dev_acc_combo_dict.items()])
|
||||
|
||||
# using only context
|
||||
el_pipe.context_weight = 1
|
||||
el_pipe.prior_weight = 0
|
||||
dev_acc_context, dev_acc_1_0_dict = _measure_accuracy(dev_data, el_pipe)
|
||||
print("dev acc context:", round(dev_acc_context, 3), [(x, round(y, 3)) for x, y in dev_acc_1_0_dict.items()])
|
||||
print()
|
||||
# using only context
|
||||
el_pipe.context_weight = 1
|
||||
el_pipe.prior_weight = 0
|
||||
dev_acc_context, dev_acc_context_dict = _measure_accuracy(dev_data, el_pipe)
|
||||
print("dev acc context avg:", round(dev_acc_context, 3),
|
||||
[(x, round(y, 3)) for x, y in dev_acc_context_dict.items()])
|
||||
print()
|
||||
|
||||
# reset for follow-up tests
|
||||
el_pipe.context_weight = 1
|
||||
el_pipe.prior_weight = 1
|
||||
|
||||
# STEP 8: apply the EL pipe on a toy example
|
||||
if to_test_pipeline:
|
||||
print()
|
||||
print("STEP 8: applying Entity Linking to toy example", datetime.datetime.now())
|
||||
|
@ -215,6 +228,7 @@ def run_pipeline():
|
|||
run_el_toy_example(nlp=nlp_2)
|
||||
print()
|
||||
|
||||
# STEP 9: write the NLP pipeline (including entity linker) to file
|
||||
if to_write_nlp:
|
||||
print()
|
||||
print("STEP 9: testing NLP IO", datetime.datetime.now())
|
||||
|
@ -225,6 +239,7 @@ def run_pipeline():
|
|||
print("reading from", NLP_2_DIR)
|
||||
nlp_3 = spacy.load(NLP_2_DIR)
|
||||
|
||||
# verify that the IO has gone correctly
|
||||
if to_read_nlp:
|
||||
print()
|
||||
print("running toy example with NLP 2")
|
||||
|
@ -272,6 +287,7 @@ def _measure_accuracy(data, el_pipe):
|
|||
|
||||
|
||||
def _measure_baselines(data, kb):
|
||||
# Measure 3 performance baselines: random selection, prior probabilities, and 'oracle' prediction for upper bound
|
||||
random_correct_by_label = dict()
|
||||
random_incorrect_by_label = dict()
|
||||
|
||||
|
@ -362,7 +378,7 @@ def calculate_acc(correct_by_label, incorrect_by_label):
|
|||
return acc, acc_by_label
|
||||
|
||||
|
||||
def test_kb(kb):
|
||||
def check_kb(kb):
|
||||
for mention in ("Bush", "Douglas Adams", "Homer", "Brazil", "China"):
|
||||
candidates = kb.get_candidates(mention)
|
||||
|
||||
|
@ -384,7 +400,7 @@ def run_el_toy_example(nlp):
|
|||
print()
|
||||
|
||||
# Q4426480 is her husband
|
||||
text = "Ada Lovelace was the countess of Lovelace. She is known for her programming work on the analytical engine. "\
|
||||
text = "Ada Lovelace was the countess of Lovelace. She's known for her programming work on the analytical engine. "\
|
||||
"She loved her husband William King dearly. "
|
||||
doc = nlp(text)
|
||||
print(text)
|
||||
|
@ -393,7 +409,7 @@ def run_el_toy_example(nlp):
|
|||
print()
|
||||
|
||||
# Q3568763 is her tutor
|
||||
text = "Ada Lovelace was the countess of Lovelace. She is known for her programming work on the analytical engine. "\
|
||||
text = "Ada Lovelace was the countess of Lovelace. She's known for her programming work on the analytical engine. "\
|
||||
"She was tutored by her favorite physics tutor William King."
|
||||
doc = nlp(text)
|
||||
print(text)
|
|
@ -661,10 +661,11 @@ def build_nel_encoder(in_width, hidden_width, end_width, **cfg):
|
|||
LN(Maxout(hidden_width, hidden_width * 3, pieces=cnn_maxout_pieces))))
|
||||
|
||||
encoder = SpacyVectors \
|
||||
>> with_flatten(LN(Maxout(hidden_width, in_width)) >> convolution ** conv_depth, pad=conv_depth) \
|
||||
>> with_flatten(Affine(hidden_width, in_width))\
|
||||
>> with_flatten(LN(Maxout(hidden_width, hidden_width)) >> convolution ** conv_depth, pad=conv_depth) \
|
||||
>> flatten_add_lengths \
|
||||
>> ParametricAttention(hidden_width) \
|
||||
>> Pooling(mean_pool) \
|
||||
>> Pooling(sum_pool) \
|
||||
>> Residual(zero_init(Maxout(hidden_width, hidden_width))) \
|
||||
>> zero_init(Affine(end_width, hidden_width, drop_factor=0.0))
|
||||
|
||||
|
|
|
@ -1078,33 +1078,19 @@ class EntityLinker(Pipe):
|
|||
raise ValueError("entity_width not found")
|
||||
|
||||
embed_width = cfg.get("embed_width", 300)
|
||||
hidden_width = cfg.get("hidden_width", 32)
|
||||
entity_width = cfg.get("entity_width") # no default because this needs to correspond with the KB
|
||||
sent_width = entity_width
|
||||
hidden_width = cfg.get("hidden_width", 128)
|
||||
|
||||
# no default because this needs to correspond with the KB entity length
|
||||
sent_width = cfg.get("entity_width")
|
||||
|
||||
model = build_nel_encoder(in_width=embed_width, hidden_width=hidden_width, end_width=sent_width, **cfg)
|
||||
|
||||
# dimension of the mention encoder needs to match the dimension of the entity encoder
|
||||
# article_width = cfg.get("article_width", 128)
|
||||
# sent_width = cfg.get("sent_width", 64)
|
||||
# article_encoder = build_nel_encoder(in_width=embed_width, hidden_width=hidden_width, end_width=article_width, **cfg)
|
||||
# mention_width = article_width + sent_width
|
||||
# mention_encoder = Affine(entity_width, mention_width, drop_factor=0.0)
|
||||
# return article_encoder, sent_encoder, mention_encoder
|
||||
|
||||
return model
|
||||
|
||||
def __init__(self, **cfg):
|
||||
# self.article_encoder = True
|
||||
# self.sent_encoder = True
|
||||
# self.mention_encoder = True
|
||||
self.model = True
|
||||
self.kb = None
|
||||
self.cfg = dict(cfg)
|
||||
self.doc_cutoff = self.cfg.get("doc_cutoff", 5)
|
||||
# self.sgd_article = None
|
||||
# self.sgd_sent = None
|
||||
# self.sgd_mention = None
|
||||
|
||||
def set_kb(self, kb):
|
||||
self.kb = kb
|
||||
|
@ -1131,13 +1117,6 @@ class EntityLinker(Pipe):
|
|||
sgd = self.create_optimizer()
|
||||
return sgd
|
||||
|
||||
# if self.mention_encoder is True:
|
||||
# self.article_encoder, self.sent_encoder, self.mention_encoder = self.Model(**self.cfg)
|
||||
# self.sgd_article = create_default_optimizer(self.article_encoder.ops)
|
||||
# self.sgd_sent = create_default_optimizer(self.sent_encoder.ops)
|
||||
# self.sgd_mention = create_default_optimizer(self.mention_encoder.ops)
|
||||
# return self.sgd_article
|
||||
|
||||
def update(self, docs, golds, state=None, drop=0.0, sgd=None, losses=None):
|
||||
self.require_model()
|
||||
self.require_kb()
|
||||
|
@ -1166,15 +1145,11 @@ class EntityLinker(Pipe):
|
|||
mention = doc.text[start:end]
|
||||
sent_start = 0
|
||||
sent_end = len(doc)
|
||||
first_par_end = len(doc)
|
||||
for index, sent in enumerate(doc.sents):
|
||||
if start >= sent.start_char and end <= sent.end_char:
|
||||
sent_start = sent.start
|
||||
sent_end = sent.end
|
||||
if index == self.doc_cutoff-1:
|
||||
first_par_end = sent.end
|
||||
sentence = doc[sent_start:sent_end].as_doc()
|
||||
first_par = doc[0:first_par_end].as_doc()
|
||||
|
||||
candidates = self.kb.get_candidates(mention)
|
||||
for c in candidates:
|
||||
|
@ -1184,32 +1159,15 @@ class EntityLinker(Pipe):
|
|||
prior_prob = c.prior_prob
|
||||
entity_encoding = c.entity_vector
|
||||
entity_encodings.append(entity_encoding)
|
||||
# article_docs.append(first_par)
|
||||
sentence_docs.append(sentence)
|
||||
|
||||
if len(entity_encodings) > 0:
|
||||
# doc_encodings, bp_doc = self.article_encoder.begin_update(article_docs, drop=drop)
|
||||
# sent_encodings, bp_sent = self.sent_encoder.begin_update(sentence_docs, drop=drop)
|
||||
|
||||
# concat_encodings = [list(doc_encodings[i]) + list(sent_encodings[i]) for i in range(len(article_docs))]
|
||||
# mention_encodings, bp_mention = self.mention_encoder.begin_update(np.asarray(concat_encodings), drop=drop)
|
||||
|
||||
sent_encodings, bp_sent = self.model.begin_update(sentence_docs, drop=drop)
|
||||
entity_encodings = np.asarray(entity_encodings, dtype=np.float32)
|
||||
|
||||
loss, d_scores = self.get_loss(scores=sent_encodings, golds=entity_encodings, docs=None)
|
||||
bp_sent(d_scores, sgd=sgd)
|
||||
|
||||
# gradient : concat (doc+sent) vs. desc
|
||||
# sent_start = self.article_encoder.nO
|
||||
# sent_gradients = list()
|
||||
# doc_gradients = list()
|
||||
# for x in mention_gradient:
|
||||
# doc_gradients.append(list(x[0:sent_start]))
|
||||
# sent_gradients.append(list(x[sent_start:]))
|
||||
# bp_doc(doc_gradients, sgd=self.sgd_article)
|
||||
# bp_sent(sent_gradients, sgd=self.sgd_sent)
|
||||
|
||||
if losses is not None:
|
||||
losses[self.name] += loss
|
||||
return loss
|
||||
|
@ -1264,21 +1222,9 @@ class EntityLinker(Pipe):
|
|||
|
||||
for i, doc in enumerate(docs):
|
||||
if len(doc) > 0:
|
||||
first_par_end = len(doc)
|
||||
for index, sent in enumerate(doc.sents):
|
||||
if index == self.doc_cutoff-1:
|
||||
first_par_end = sent.end
|
||||
first_par = doc[0:first_par_end].as_doc()
|
||||
|
||||
# doc_encoding = self.article_encoder([first_par])
|
||||
for ent in doc.ents:
|
||||
sent_doc = ent.sent.as_doc()
|
||||
if len(sent_doc) > 0:
|
||||
# sent_encoding = self.sent_encoder([sent_doc])
|
||||
# concat_encoding = [list(doc_encoding[0]) + list(sent_encoding[0])]
|
||||
# mention_encoding = self.mention_encoder(np.asarray([concat_encoding[0]]))
|
||||
# mention_enc_t = np.transpose(mention_encoding)
|
||||
|
||||
sent_encoding = self.model([sent_doc])
|
||||
sent_enc_t = np.transpose(sent_encoding)
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user