sentence encoder only (removing article/mention encoder)

This commit is contained in:
svlandeg 2019-06-18 00:05:47 +02:00
parent 6332af40de
commit ffae7d3555
3 changed files with 95 additions and 90 deletions

View File

@ -294,7 +294,6 @@ def read_training(nlp, training_dir, dev, limit):
# we assume the data is written sequentially
current_article_id = None
current_doc = None
gold_entities = list()
ents_by_offset = dict()
skip_articles = set()
total_entities = 0
@ -302,8 +301,6 @@ def read_training(nlp, training_dir, dev, limit):
with open(entityfile_loc, mode='r', encoding='utf8') as file:
for line in file:
if not limit or len(data) < limit:
if len(data) > 0 and len(data) % 50 == 0:
print("Read", total_entities, "entities in", len(data), "articles")
fields = line.replace('\n', "").split(sep='|')
article_id = fields[0]
alias = fields[1]
@ -313,34 +310,42 @@ def read_training(nlp, training_dir, dev, limit):
if dev == is_dev(article_id) and article_id != "article_id" and article_id not in skip_articles:
if not current_doc or (current_article_id != article_id):
# store the data from the previous article
if gold_entities and current_doc:
gold = GoldParse(doc=current_doc, links=gold_entities)
data.append((current_doc, gold))
total_entities += len(gold_entities)
# parse the new article text
file_name = article_id + ".txt"
try:
with open(os.path.join(training_dir, file_name), mode="r", encoding='utf8') as f:
text = f.read()
current_doc = nlp(text)
for ent in current_doc.ents:
ents_by_offset[str(ent.start_char) + "_" + str(ent.end_char)] = ent.text
if len(text) < 30000: # threshold for convenience / speed of processing
current_doc = nlp(text)
current_article_id = article_id
ents_by_offset = dict()
for ent in current_doc.ents:
ents_by_offset[str(ent.start_char) + "_" + str(ent.end_char)] = ent
else:
skip_articles.add(current_article_id)
current_doc = None
except Exception as e:
print("Problem parsing article", article_id, e)
current_article_id = article_id
gold_entities = list()
# repeat checking this condition in case an exception was thrown
if current_doc and (current_article_id == article_id):
found_ent = ents_by_offset.get(start + "_" + end, None)
if found_ent:
if found_ent != alias:
if found_ent.text != alias:
skip_articles.add(current_article_id)
current_doc = None
else:
gold_entities.append((int(start), int(end), wp_title))
sent = found_ent.sent.as_doc()
# currently feeding the gold data one entity per sentence at a time
gold_start = int(start) - found_ent.sent.start_char
gold_end = int(end) - found_ent.sent.start_char
gold_entities = list()
gold_entities.append((gold_start, gold_end, wp_title))
gold = GoldParse(doc=current_doc, links=gold_entities)
data.append((sent, gold))
total_entities += 1
if len(data) % 500 == 0:
print(" -read", total_entities, "entities")
print("Read", total_entities, "entities in", len(data), "articles")
print(" -read", total_entities, "entities")
return data

View File

@ -9,7 +9,6 @@ from examples.pipeline.wiki_entity_linking import wikipedia_processor as wp, kb_
from examples.pipeline.wiki_entity_linking.kb_creator import DESC_WIDTH
import spacy
from spacy.vocab import Vocab
from spacy.kb import KnowledgeBase
import datetime
@ -64,8 +63,8 @@ def run_pipeline():
to_test_pipeline = True
# write the NLP object, read back in and test again
to_write_nlp = True
to_read_nlp = True
to_write_nlp = False
to_read_nlp = False
# STEP 1 : create prior probabilities from WP
# run only once !
@ -134,8 +133,8 @@ def run_pipeline():
if train_pipe:
print("STEP 6: training Entity Linking pipe", datetime.datetime.now())
train_limit = 5
dev_limit = 2
train_limit = 25000
dev_limit = 1000
train_data = training_set_creator.read_training(nlp=nlp_2,
training_dir=TRAINING_DIR,
@ -345,7 +344,11 @@ def calculate_acc(correct_by_label, incorrect_by_label):
acc_by_label = dict()
total_correct = 0
total_incorrect = 0
for label, correct in correct_by_label.items():
all_keys = set()
all_keys.update(correct_by_label.keys())
all_keys.update(incorrect_by_label.keys())
for label in sorted(all_keys):
correct = correct_by_label.get(label, 0)
incorrect = incorrect_by_label.get(label, 0)
total_correct += correct
total_incorrect += incorrect

View File

@ -1079,36 +1079,39 @@ class EntityLinker(Pipe):
embed_width = cfg.get("embed_width", 300)
hidden_width = cfg.get("hidden_width", 32)
article_width = cfg.get("article_width", 128)
sent_width = cfg.get("sent_width", 64)
entity_width = cfg.get("entity_width") # no default because this needs to correspond with the KB
sent_width = entity_width
article_encoder = build_nel_encoder(in_width=embed_width, hidden_width=hidden_width, end_width=article_width, **cfg)
sent_encoder = build_nel_encoder(in_width=embed_width, hidden_width=hidden_width, end_width=sent_width, **cfg)
model = build_nel_encoder(in_width=embed_width, hidden_width=hidden_width, end_width=sent_width, **cfg)
# dimension of the mention encoder needs to match the dimension of the entity encoder
mention_width = article_width + sent_width
mention_encoder = Affine(entity_width, mention_width, drop_factor=0.0)
# article_width = cfg.get("article_width", 128)
# sent_width = cfg.get("sent_width", 64)
# article_encoder = build_nel_encoder(in_width=embed_width, hidden_width=hidden_width, end_width=article_width, **cfg)
# mention_width = article_width + sent_width
# mention_encoder = Affine(entity_width, mention_width, drop_factor=0.0)
# return article_encoder, sent_encoder, mention_encoder
return article_encoder, sent_encoder, mention_encoder
return model
def __init__(self, **cfg):
self.article_encoder = True
self.sent_encoder = True
self.mention_encoder = True
# self.article_encoder = True
# self.sent_encoder = True
# self.mention_encoder = True
self.model = True
self.kb = None
self.cfg = dict(cfg)
self.doc_cutoff = self.cfg.get("doc_cutoff", 5)
self.sgd_article = None
self.sgd_sent = None
self.sgd_mention = None
# self.sgd_article = None
# self.sgd_sent = None
# self.sgd_mention = None
def set_kb(self, kb):
self.kb = kb
def require_model(self):
# Raise an error if the component's model is not initialized.
if getattr(self, "mention_encoder", None) in (None, True, False):
if getattr(self, "model", None) in (None, True, False):
raise ValueError(Errors.E109.format(name=self.name))
def require_kb(self):
@ -1121,12 +1124,19 @@ class EntityLinker(Pipe):
self.require_kb()
self.cfg["entity_width"] = self.kb.entity_vector_length
if self.mention_encoder is True:
self.article_encoder, self.sent_encoder, self.mention_encoder = self.Model(**self.cfg)
self.sgd_article = create_default_optimizer(self.article_encoder.ops)
self.sgd_sent = create_default_optimizer(self.sent_encoder.ops)
self.sgd_mention = create_default_optimizer(self.mention_encoder.ops)
return self.sgd_article
if self.model is True:
self.model = self.Model(**self.cfg)
if sgd is None:
sgd = self.create_optimizer()
return sgd
# if self.mention_encoder is True:
# self.article_encoder, self.sent_encoder, self.mention_encoder = self.Model(**self.cfg)
# self.sgd_article = create_default_optimizer(self.article_encoder.ops)
# self.sgd_sent = create_default_optimizer(self.sent_encoder.ops)
# self.sgd_mention = create_default_optimizer(self.mention_encoder.ops)
# return self.sgd_article
def update(self, docs, golds, state=None, drop=0.0, sgd=None, losses=None):
self.require_model()
@ -1146,7 +1156,7 @@ class EntityLinker(Pipe):
docs = [docs]
golds = [golds]
article_docs = list()
# article_docs = list()
sentence_docs = list()
entity_encodings = list()
@ -1173,34 +1183,32 @@ class EntityLinker(Pipe):
if kb_id == gold_kb:
prior_prob = c.prior_prob
entity_encoding = c.entity_vector
entity_encodings.append(entity_encoding)
article_docs.append(first_par)
# article_docs.append(first_par)
sentence_docs.append(sentence)
if len(entity_encodings) > 0:
doc_encodings, bp_doc = self.article_encoder.begin_update(article_docs, drop=drop)
sent_encodings, bp_sent = self.sent_encoder.begin_update(sentence_docs, drop=drop)
# doc_encodings, bp_doc = self.article_encoder.begin_update(article_docs, drop=drop)
# sent_encodings, bp_sent = self.sent_encoder.begin_update(sentence_docs, drop=drop)
concat_encodings = [list(doc_encodings[i]) + list(sent_encodings[i]) for i in
range(len(article_docs))]
mention_encodings, bp_mention = self.mention_encoder.begin_update(np.asarray(concat_encodings), drop=drop)
# concat_encodings = [list(doc_encodings[i]) + list(sent_encodings[i]) for i in range(len(article_docs))]
# mention_encodings, bp_mention = self.mention_encoder.begin_update(np.asarray(concat_encodings), drop=drop)
sent_encodings, bp_sent = self.model.begin_update(sentence_docs, drop=drop)
entity_encodings = np.asarray(entity_encodings, dtype=np.float32)
loss, d_scores = self.get_loss(scores=mention_encodings, golds=entity_encodings, docs=None)
mention_gradient = bp_mention(d_scores, sgd=self.sgd_mention)
loss, d_scores = self.get_loss(scores=sent_encodings, golds=entity_encodings, docs=None)
bp_sent(d_scores, sgd=sgd)
# gradient : concat (doc+sent) vs. desc
sent_start = self.article_encoder.nO
sent_gradients = list()
doc_gradients = list()
for x in mention_gradient:
doc_gradients.append(list(x[0:sent_start]))
sent_gradients.append(list(x[sent_start:]))
bp_doc(doc_gradients, sgd=self.sgd_article)
bp_sent(sent_gradients, sgd=self.sgd_sent)
# sent_start = self.article_encoder.nO
# sent_gradients = list()
# doc_gradients = list()
# for x in mention_gradient:
# doc_gradients.append(list(x[0:sent_start]))
# sent_gradients.append(list(x[sent_start:]))
# bp_doc(doc_gradients, sgd=self.sgd_article)
# bp_sent(sent_gradients, sgd=self.sgd_sent)
if losses is not None:
losses[self.name] += loss
@ -1262,14 +1270,17 @@ class EntityLinker(Pipe):
first_par_end = sent.end
first_par = doc[0:first_par_end].as_doc()
doc_encoding = self.article_encoder([first_par])
# doc_encoding = self.article_encoder([first_par])
for ent in doc.ents:
sent_doc = ent.sent.as_doc()
if len(sent_doc) > 0:
sent_encoding = self.sent_encoder([sent_doc])
concat_encoding = [list(doc_encoding[0]) + list(sent_encoding[0])]
mention_encoding = self.mention_encoder(np.asarray([concat_encoding[0]]))
mention_enc_t = np.transpose(mention_encoding)
# sent_encoding = self.sent_encoder([sent_doc])
# concat_encoding = [list(doc_encoding[0]) + list(sent_encoding[0])]
# mention_encoding = self.mention_encoder(np.asarray([concat_encoding[0]]))
# mention_enc_t = np.transpose(mention_encoding)
sent_encoding = self.model([sent_doc])
sent_enc_t = np.transpose(sent_encoding)
candidates = self.kb.get_candidates(ent.text)
if candidates:
@ -1278,7 +1289,7 @@ class EntityLinker(Pipe):
prior_prob = c.prior_prob * self.prior_weight
kb_id = c.entity_
entity_encoding = c.entity_vector
sim = float(cosine(np.asarray([entity_encoding]), mention_enc_t)) * self.context_weight
sim = float(cosine(np.asarray([entity_encoding]), sent_enc_t)) * self.context_weight
score = prior_prob + sim - (prior_prob*sim) # put weights on the different factors ?
scores.append(score)
@ -1299,34 +1310,20 @@ class EntityLinker(Pipe):
serialize = OrderedDict()
serialize["cfg"] = lambda p: srsly.write_json(p, self.cfg)
serialize["kb"] = lambda p: self.kb.dump(p)
if self.mention_encoder not in (None, True, False):
serialize["article_encoder"] = lambda p: p.open("wb").write(self.article_encoder.to_bytes())
serialize["sent_encoder"] = lambda p: p.open("wb").write(self.sent_encoder.to_bytes())
serialize["mention_encoder"] = lambda p: p.open("wb").write(self.mention_encoder.to_bytes())
if self.model not in (None, True, False):
serialize["model"] = lambda p: p.open("wb").write(self.model.to_bytes())
exclude = util.get_serialization_exclude(serialize, exclude, kwargs)
util.to_disk(path, serialize, exclude)
def from_disk(self, path, exclude=tuple(), **kwargs):
def load_article_encoder(p):
if self.article_encoder is True:
self.article_encoder, _, _ = self.Model(**self.cfg)
self.article_encoder.from_bytes(p.open("rb").read())
def load_sent_encoder(p):
if self.sent_encoder is True:
_, self.sent_encoder, _ = self.Model(**self.cfg)
self.sent_encoder.from_bytes(p.open("rb").read())
def load_mention_encoder(p):
if self.mention_encoder is True:
_, _, self.mention_encoder = self.Model(**self.cfg)
self.mention_encoder.from_bytes(p.open("rb").read())
def load_model(p):
if self.model is True:
self.model = self.Model(**self.cfg)
self.model.from_bytes(p.open("rb").read())
deserialize = OrderedDict()
deserialize["cfg"] = lambda p: self.cfg.update(_load_cfg(p))
deserialize["article_encoder"] = load_article_encoder
deserialize["sent_encoder"] = load_sent_encoder
deserialize["mention_encoder"] = load_mention_encoder
deserialize["model"] = load_model
exclude = util.get_serialization_exclude(deserialize, exclude, kwargs)
util.from_disk(path, deserialize, exclude)
return self