Mirror of https://github.com/explosion/spaCy.git (synced 2025-02-10 00:20:35 +03:00)

sentence encoder only (removing article/mention encoder)
This commit is contained in:
parent 6332af40de
commit ffae7d3555

@@ -294,7 +294,6 @@ def read_training(nlp, training_dir, dev, limit):
     # we assume the data is written sequentially
     current_article_id = None
     current_doc = None
-    gold_entities = list()
     ents_by_offset = dict()
     skip_articles = set()
     total_entities = 0

@@ -302,8 +301,6 @@ def read_training(nlp, training_dir, dev, limit):
     with open(entityfile_loc, mode='r', encoding='utf8') as file:
         for line in file:
             if not limit or len(data) < limit:
-                if len(data) > 0 and len(data) % 50 == 0:
-                    print("Read", total_entities, "entities in", len(data), "articles")
                 fields = line.replace('\n', "").split(sep='|')
                 article_id = fields[0]
                 alias = fields[1]

@@ -313,34 +310,42 @@ def read_training(nlp, training_dir, dev, limit):

                 if dev == is_dev(article_id) and article_id != "article_id" and article_id not in skip_articles:
                     if not current_doc or (current_article_id != article_id):
-                        # store the data from the previous article
-                        if gold_entities and current_doc:
-                            gold = GoldParse(doc=current_doc, links=gold_entities)
-                            data.append((current_doc, gold))
-                            total_entities += len(gold_entities)
-
                         # parse the new article text
                         file_name = article_id + ".txt"
                         try:
                             with open(os.path.join(training_dir, file_name), mode="r", encoding='utf8') as f:
                                 text = f.read()
+                            if len(text) < 30000:   # threshold for convenience / speed of processing
                                 current_doc = nlp(text)
+                                current_article_id = article_id
+                                ents_by_offset = dict()
                                 for ent in current_doc.ents:
-                                    ents_by_offset[str(ent.start_char) + "_" + str(ent.end_char)] = ent.text
+                                    ents_by_offset[str(ent.start_char) + "_" + str(ent.end_char)] = ent
+                            else:
+                                skip_articles.add(current_article_id)
+                                current_doc = None
                         except Exception as e:
                             print("Problem parsing article", article_id, e)

-                        current_article_id = article_id
-                        gold_entities = list()
-
                     # repeat checking this condition in case an exception was thrown
                     if current_doc and (current_article_id == article_id):
                         found_ent = ents_by_offset.get(start + "_" + end, None)
                         if found_ent:
-                            if found_ent != alias:
+                            if found_ent.text != alias:
                                 skip_articles.add(current_article_id)
+                                current_doc = None
                             else:
-                                gold_entities.append((int(start), int(end), wp_title))
+                                sent = found_ent.sent.as_doc()
+                                # currently feeding the gold data one entity per sentence at a time
+                                gold_start = int(start) - found_ent.sent.start_char
+                                gold_end = int(end) - found_ent.sent.start_char
+                                gold_entities = list()
+                                gold_entities.append((gold_start, gold_end, wp_title))
+                                gold = GoldParse(doc=current_doc, links=gold_entities)
+                                data.append((sent, gold))
+                                total_entities += 1
+                                if len(data) % 500 == 0:
+                                    print(" -read", total_entities, "entities")

-    print("Read", total_entities, "entities in", len(data), "articles")
+    print(" -read", total_entities, "entities")
     return data

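The hunk above switches the training examples from whole articles to single sentences: the article-level character offsets from the training file are re-expressed relative to the sentence that contains the mention. A minimal standalone sketch of that offset arithmetic, not part of the diff; the model name and example text are assumptions (any pipeline with NER and sentence boundaries would do):

import spacy

nlp = spacy.load("en_core_web_sm")  # assumption: this model is installed
doc = nlp("The film was shot in Berlin. Douglas Adams wrote the novel in 1979.")

for ent in doc.ents:
    sent = ent.sent                                  # sentence span containing the mention
    sent_doc = sent.as_doc()                         # standalone Doc, as fed to the sentence encoder
    gold_start = ent.start_char - sent.start_char    # offset relative to the sentence, not the article
    gold_end = ent.end_char - sent.start_char
    print(ent.text, (gold_start, gold_end), "->", sent_doc.text[gold_start:gold_end])
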
@@ -9,7 +9,6 @@ from examples.pipeline.wiki_entity_linking import wikipedia_processor as wp, kb_
 from examples.pipeline.wiki_entity_linking.kb_creator import DESC_WIDTH

 import spacy
-from spacy.vocab import Vocab
 from spacy.kb import KnowledgeBase
 import datetime

@@ -64,8 +63,8 @@ def run_pipeline():
     to_test_pipeline = True

     # write the NLP object, read back in and test again
-    to_write_nlp = True
-    to_read_nlp = True
+    to_write_nlp = False
+    to_read_nlp = False

     # STEP 1 : create prior probabilities from WP
     # run only once !

@@ -134,8 +133,8 @@ def run_pipeline():

     if train_pipe:
         print("STEP 6: training Entity Linking pipe", datetime.datetime.now())
-        train_limit = 5
-        dev_limit = 2
+        train_limit = 25000
+        dev_limit = 1000

         train_data = training_set_creator.read_training(nlp=nlp_2,
                                                         training_dir=TRAINING_DIR,

@@ -345,7 +344,11 @@ def calculate_acc(correct_by_label, incorrect_by_label):
     acc_by_label = dict()
     total_correct = 0
     total_incorrect = 0
-    for label, correct in correct_by_label.items():
+    all_keys = set()
+    all_keys.update(correct_by_label.keys())
+    all_keys.update(incorrect_by_label.keys())
+    for label in sorted(all_keys):
+        correct = correct_by_label.get(label, 0)
         incorrect = incorrect_by_label.get(label, 0)
         total_correct += correct
         total_incorrect += incorrect

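For reference, a hypothetical standalone version of the accuracy helper as it looks after this hunk, iterating over the union of labels so a label that only ever appears in the incorrect counts still contributes to the totals. The final division step and the example counts are assumptions, since the diff does not show them:

def calculate_acc(correct_by_label, incorrect_by_label):
    acc_by_label = dict()
    total_correct = 0
    total_incorrect = 0
    all_keys = set()
    all_keys.update(correct_by_label.keys())
    all_keys.update(incorrect_by_label.keys())
    for label in sorted(all_keys):
        correct = correct_by_label.get(label, 0)
        incorrect = incorrect_by_label.get(label, 0)
        total_correct += correct
        total_incorrect += incorrect
        acc_by_label[label] = correct / (correct + incorrect) if (correct + incorrect) else 0.0
    acc = total_correct / (total_correct + total_incorrect) if (total_correct + total_incorrect) else 0.0
    return acc, acc_by_label

# labels that only occur in the incorrect counts (here "GPE") are no longer skipped
print(calculate_acc({"PERSON": 8, "ORG": 2}, {"PERSON": 2, "GPE": 3}))
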
@@ -1079,36 +1079,39 @@ class EntityLinker(Pipe):

         embed_width = cfg.get("embed_width", 300)
         hidden_width = cfg.get("hidden_width", 32)
-        article_width = cfg.get("article_width", 128)
-        sent_width = cfg.get("sent_width", 64)
         entity_width = cfg.get("entity_width")  # no default because this needs to correspond with the KB
+        sent_width = entity_width

-        article_encoder = build_nel_encoder(in_width=embed_width, hidden_width=hidden_width, end_width=article_width, **cfg)
-        sent_encoder = build_nel_encoder(in_width=embed_width, hidden_width=hidden_width, end_width=sent_width, **cfg)
+        model = build_nel_encoder(in_width=embed_width, hidden_width=hidden_width, end_width=sent_width, **cfg)

         # dimension of the mention encoder needs to match the dimension of the entity encoder
-        mention_width = article_width + sent_width
-        mention_encoder = Affine(entity_width, mention_width, drop_factor=0.0)
+        # article_width = cfg.get("article_width", 128)
+        # sent_width = cfg.get("sent_width", 64)
+        # article_encoder = build_nel_encoder(in_width=embed_width, hidden_width=hidden_width, end_width=article_width, **cfg)
+        # mention_width = article_width + sent_width
+        # mention_encoder = Affine(entity_width, mention_width, drop_factor=0.0)
+        # return article_encoder, sent_encoder, mention_encoder

-        return article_encoder, sent_encoder, mention_encoder
+        return model

     def __init__(self, **cfg):
-        self.article_encoder = True
-        self.sent_encoder = True
-        self.mention_encoder = True
+        # self.article_encoder = True
+        # self.sent_encoder = True
+        # self.mention_encoder = True
+        self.model = True
         self.kb = None
         self.cfg = dict(cfg)
         self.doc_cutoff = self.cfg.get("doc_cutoff", 5)
-        self.sgd_article = None
-        self.sgd_sent = None
-        self.sgd_mention = None
+        # self.sgd_article = None
+        # self.sgd_sent = None
+        # self.sgd_mention = None

     def set_kb(self, kb):
         self.kb = kb

     def require_model(self):
         # Raise an error if the component's model is not initialized.
-        if getattr(self, "mention_encoder", None) in (None, True, False):
+        if getattr(self, "model", None) in (None, True, False):
             raise ValueError(Errors.E109.format(name=self.name))

     def require_kb(self):

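With the Affine mention encoder gone, the sentence encoding is compared directly against the KB's entity vectors, which is why sent_width is now tied to entity_width. A toy illustration of that shape constraint; the dimensions and vectors are made up, and the cosine here is plain NumPy rather than the pipe's own helper:

import numpy as np

entity_width = 64                                                # must equal kb.entity_vector_length
sent_encoding = np.random.rand(entity_width).astype("float32")   # stand-in for the single encoder's output
entity_vector = np.random.rand(entity_width).astype("float32")   # stand-in for a candidate vector from the KB

# a direct similarity is only well-defined because the two widths match
sim = float(np.dot(sent_encoding, entity_vector) /
            (np.linalg.norm(sent_encoding) * np.linalg.norm(entity_vector)))
print(sim)
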
@@ -1121,12 +1124,19 @@ class EntityLinker(Pipe):
         self.require_kb()
         self.cfg["entity_width"] = self.kb.entity_vector_length

-        if self.mention_encoder is True:
-            self.article_encoder, self.sent_encoder, self.mention_encoder = self.Model(**self.cfg)
-            self.sgd_article = create_default_optimizer(self.article_encoder.ops)
-            self.sgd_sent = create_default_optimizer(self.sent_encoder.ops)
-            self.sgd_mention = create_default_optimizer(self.mention_encoder.ops)
-        return self.sgd_article
+        if self.model is True:
+            self.model = self.Model(**self.cfg)
+
+        if sgd is None:
+            sgd = self.create_optimizer()
+        return sgd
+
+        # if self.mention_encoder is True:
+        #     self.article_encoder, self.sent_encoder, self.mention_encoder = self.Model(**self.cfg)
+        #     self.sgd_article = create_default_optimizer(self.article_encoder.ops)
+        #     self.sgd_sent = create_default_optimizer(self.sent_encoder.ops)
+        #     self.sgd_mention = create_default_optimizer(self.mention_encoder.ops)
+        # return self.sgd_article

     def update(self, docs, golds, state=None, drop=0.0, sgd=None, losses=None):
         self.require_model()

@@ -1146,7 +1156,7 @@ class EntityLinker(Pipe):
             docs = [docs]
             golds = [golds]

-        article_docs = list()
+        # article_docs = list()
         sentence_docs = list()
         entity_encodings = list()

@@ -1173,34 +1183,32 @@ class EntityLinker(Pipe):
                     if kb_id == gold_kb:
                         prior_prob = c.prior_prob
                         entity_encoding = c.entity_vector

                         entity_encodings.append(entity_encoding)
-                        article_docs.append(first_par)
+                        # article_docs.append(first_par)
                         sentence_docs.append(sentence)

         if len(entity_encodings) > 0:
-            doc_encodings, bp_doc = self.article_encoder.begin_update(article_docs, drop=drop)
-            sent_encodings, bp_sent = self.sent_encoder.begin_update(sentence_docs, drop=drop)
-
-            concat_encodings = [list(doc_encodings[i]) + list(sent_encodings[i]) for i in
-                                range(len(article_docs))]
-            mention_encodings, bp_mention = self.mention_encoder.begin_update(np.asarray(concat_encodings), drop=drop)
+            # doc_encodings, bp_doc = self.article_encoder.begin_update(article_docs, drop=drop)
+            # sent_encodings, bp_sent = self.sent_encoder.begin_update(sentence_docs, drop=drop)
+            # concat_encodings = [list(doc_encodings[i]) + list(sent_encodings[i]) for i in range(len(article_docs))]
+            # mention_encodings, bp_mention = self.mention_encoder.begin_update(np.asarray(concat_encodings), drop=drop)

+            sent_encodings, bp_sent = self.model.begin_update(sentence_docs, drop=drop)
             entity_encodings = np.asarray(entity_encodings, dtype=np.float32)

-            loss, d_scores = self.get_loss(scores=mention_encodings, golds=entity_encodings, docs=None)
-            mention_gradient = bp_mention(d_scores, sgd=self.sgd_mention)
+            loss, d_scores = self.get_loss(scores=sent_encodings, golds=entity_encodings, docs=None)
+            bp_sent(d_scores, sgd=sgd)

             # gradient : concat (doc+sent) vs. desc
-            sent_start = self.article_encoder.nO
-            sent_gradients = list()
-            doc_gradients = list()
-            for x in mention_gradient:
-                doc_gradients.append(list(x[0:sent_start]))
-                sent_gradients.append(list(x[sent_start:]))
-
-            bp_doc(doc_gradients, sgd=self.sgd_article)
-            bp_sent(sent_gradients, sgd=self.sgd_sent)
+            # sent_start = self.article_encoder.nO
+            # sent_gradients = list()
+            # doc_gradients = list()
+            # for x in mention_gradient:
+            #     doc_gradients.append(list(x[0:sent_start]))
+            #     sent_gradients.append(list(x[sent_start:]))
+            # bp_doc(doc_gradients, sgd=self.sgd_article)
+            # bp_sent(sent_gradients, sgd=self.sgd_sent)

         if losses is not None:
             losses[self.name] += loss

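The update step now reduces to a single begin_update/backprop pair instead of splitting a concatenated mention gradient back into document and sentence parts. A self-contained sketch of that flow with made-up arrays; the L2-style loss below stands in for the pipe's get_loss, which the diff does not show:

import numpy as np

def get_loss(scores, golds):
    d_scores = scores - golds                 # gradient of 0.5 * ||scores - golds||^2
    loss = float((d_scores ** 2).sum())
    return loss, d_scores

# stand-ins for self.model.begin_update(sentence_docs) output and the gold KB vectors
sent_encodings = np.random.rand(4, 64).astype("float32")
entity_encodings = np.random.rand(4, 64).astype("float32")

loss, d_scores = get_loss(sent_encodings, entity_encodings)
# in the pipe, bp_sent(d_scores, sgd=sgd) pushes this single gradient through the encoder
print(loss, d_scores.shape)
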
@@ -1262,14 +1270,17 @@ class EntityLinker(Pipe):
                     first_par_end = sent.end
             first_par = doc[0:first_par_end].as_doc()

-            doc_encoding = self.article_encoder([first_par])
+            # doc_encoding = self.article_encoder([first_par])
            for ent in doc.ents:
                sent_doc = ent.sent.as_doc()
                if len(sent_doc) > 0:
-                    sent_encoding = self.sent_encoder([sent_doc])
-                    concat_encoding = [list(doc_encoding[0]) + list(sent_encoding[0])]
-                    mention_encoding = self.mention_encoder(np.asarray([concat_encoding[0]]))
-                    mention_enc_t = np.transpose(mention_encoding)
+                    # sent_encoding = self.sent_encoder([sent_doc])
+                    # concat_encoding = [list(doc_encoding[0]) + list(sent_encoding[0])]
+                    # mention_encoding = self.mention_encoder(np.asarray([concat_encoding[0]]))
+                    # mention_enc_t = np.transpose(mention_encoding)
+
+                    sent_encoding = self.model([sent_doc])
+                    sent_enc_t = np.transpose(sent_encoding)

                     candidates = self.kb.get_candidates(ent.text)
                     if candidates:

@@ -1278,7 +1289,7 @@ class EntityLinker(Pipe):
                             prior_prob = c.prior_prob * self.prior_weight
                             kb_id = c.entity_
                             entity_encoding = c.entity_vector
-                            sim = float(cosine(np.asarray([entity_encoding]), mention_enc_t)) * self.context_weight
+                            sim = float(cosine(np.asarray([entity_encoding]), sent_enc_t)) * self.context_weight
                             score = prior_prob + sim - (prior_prob*sim)  # put weights on the different factors ?
                             scores.append(score)

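The scoring line kept above combines the candidate's prior probability with the context similarity as a probabilistic OR, score = p + s - p*s, so either signal alone can push a candidate up. A quick numeric illustration with made-up values:

def combine(prior_prob, sim):
    return prior_prob + sim - (prior_prob * sim)

print(combine(0.9, 0.1))   # 0.91 - strong prior, weak context
print(combine(0.1, 0.9))   # 0.91 - weak prior, strong context
print(combine(0.5, 0.5))   # 0.75
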
@@ -1299,34 +1310,20 @@ class EntityLinker(Pipe):
         serialize = OrderedDict()
         serialize["cfg"] = lambda p: srsly.write_json(p, self.cfg)
         serialize["kb"] = lambda p: self.kb.dump(p)
-        if self.mention_encoder not in (None, True, False):
-            serialize["article_encoder"] = lambda p: p.open("wb").write(self.article_encoder.to_bytes())
-            serialize["sent_encoder"] = lambda p: p.open("wb").write(self.sent_encoder.to_bytes())
-            serialize["mention_encoder"] = lambda p: p.open("wb").write(self.mention_encoder.to_bytes())
+        if self.model not in (None, True, False):
+            serialize["model"] = lambda p: p.open("wb").write(self.model.to_bytes())
         exclude = util.get_serialization_exclude(serialize, exclude, kwargs)
         util.to_disk(path, serialize, exclude)

     def from_disk(self, path, exclude=tuple(), **kwargs):
-        def load_article_encoder(p):
-            if self.article_encoder is True:
-                self.article_encoder, _, _ = self.Model(**self.cfg)
-            self.article_encoder.from_bytes(p.open("rb").read())
-
-        def load_sent_encoder(p):
-            if self.sent_encoder is True:
-                _, self.sent_encoder, _ = self.Model(**self.cfg)
-            self.sent_encoder.from_bytes(p.open("rb").read())
-
-        def load_mention_encoder(p):
-            if self.mention_encoder is True:
-                _, _, self.mention_encoder = self.Model(**self.cfg)
-            self.mention_encoder.from_bytes(p.open("rb").read())
+        def load_model(p):
+            if self.model is True:
+                self.model = self.Model(**self.cfg)
+            self.model.from_bytes(p.open("rb").read())

         deserialize = OrderedDict()
         deserialize["cfg"] = lambda p: self.cfg.update(_load_cfg(p))
-        deserialize["article_encoder"] = load_article_encoder
-        deserialize["sent_encoder"] = load_sent_encoder
-        deserialize["mention_encoder"] = load_mention_encoder
+        deserialize["model"] = load_model
         exclude = util.get_serialization_exclude(deserialize, exclude, kwargs)
         util.from_disk(path, deserialize, exclude)
         return self

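Serialization shrinks to a single "model" entry. The pattern itself is a dict of name-to-writer callbacks applied under a directory; a minimal sketch of that pattern, independent of spaCy internals, with made-up paths and byte content:

from collections import OrderedDict
from pathlib import Path
import json

def to_disk(path, writers):
    path = Path(path)
    path.mkdir(parents=True, exist_ok=True)
    for name, write in writers.items():
        write(path / name)        # each callback writes one file named after its key

cfg = {"hidden_width": 32}
model_bytes = b"\x00\x01"         # stand-in for self.model.to_bytes()

serialize = OrderedDict()
serialize["cfg"] = lambda p: p.write_text(json.dumps(cfg))
serialize["model"] = lambda p: p.open("wb").write(model_bytes)

to_disk("el_component", serialize)   # writes el_component/cfg and el_component/model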