mirror of https://github.com/explosion/spaCy.git
sentence encoder only (removing article/mention encoder)
This commit is contained in:
parent 6332af40de
commit ffae7d3555
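With this change the EntityLinker keeps a single sentence encoder whose output width matches the KB entity vectors (sent_width = entity_width), and each candidate is scored by combining its prior probability with the cosine similarity between the sentence encoding and the candidate's entity vector. A minimal standalone sketch of that scoring step in plain numpy (illustrative only; score_candidates and the candidate tuples are made-up stand-ins, not spaCy API):

    import numpy as np

    def cosine(a, b):
        # cosine similarity between two 1-D vectors
        return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

    def score_candidates(sent_encoding, candidates, prior_weight=1.0, context_weight=1.0):
        # candidates: list of (kb_id, prior_prob, entity_vector) tuples
        scored = []
        for kb_id, prior_prob, entity_vector in candidates:
            prior = prior_prob * prior_weight
            sim = cosine(np.asarray(sent_encoding), np.asarray(entity_vector)) * context_weight
            # same combination as in the diff below: prior + sim - prior * sim
            scored.append((kb_id, prior + sim - prior * sim))
        return max(scored, key=lambda s: s[1])  # best-scoring candidate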
@@ -294,7 +294,6 @@ def read_training(nlp, training_dir, dev, limit):
# we assume the data is written sequentially
current_article_id = None
current_doc = None
gold_entities = list()
ents_by_offset = dict()
skip_articles = set()
total_entities = 0
@@ -302,8 +301,6 @@ def read_training(nlp, training_dir, dev, limit):
with open(entityfile_loc, mode='r', encoding='utf8') as file:
for line in file:
if not limit or len(data) < limit:
if len(data) > 0 and len(data) % 50 == 0:
print("Read", total_entities, "entities in", len(data), "articles")
fields = line.replace('\n', "").split(sep='|')
article_id = fields[0]
alias = fields[1]
@@ -313,34 +310,42 @@ def read_training(nlp, training_dir, dev, limit):
if dev == is_dev(article_id) and article_id != "article_id" and article_id not in skip_articles:
if not current_doc or (current_article_id != article_id):
# store the data from the previous article
if gold_entities and current_doc:
gold = GoldParse(doc=current_doc, links=gold_entities)
data.append((current_doc, gold))
total_entities += len(gold_entities)

# parse the new article text
file_name = article_id + ".txt"
try:
with open(os.path.join(training_dir, file_name), mode="r", encoding='utf8') as f:
text = f.read()
current_doc = nlp(text)
for ent in current_doc.ents:
ents_by_offset[str(ent.start_char) + "_" + str(ent.end_char)] = ent.text
if len(text) < 30000: # threshold for convenience / speed of processing
current_doc = nlp(text)
current_article_id = article_id
ents_by_offset = dict()
for ent in current_doc.ents:
ents_by_offset[str(ent.start_char) + "_" + str(ent.end_char)] = ent
else:
skip_articles.add(current_article_id)
current_doc = None
except Exception as e:
print("Problem parsing article", article_id, e)

current_article_id = article_id
gold_entities = list()

# repeat checking this condition in case an exception was thrown
if current_doc and (current_article_id == article_id):
found_ent = ents_by_offset.get(start + "_" + end, None)
if found_ent:
if found_ent != alias:
if found_ent.text != alias:
skip_articles.add(current_article_id)
current_doc = None
else:
gold_entities.append((int(start), int(end), wp_title))
sent = found_ent.sent.as_doc()
# currently feeding the gold data one entity per sentence at a time
gold_start = int(start) - found_ent.sent.start_char
gold_end = int(end) - found_ent.sent.start_char
gold_entities = list()
gold_entities.append((gold_start, gold_end, wp_title))
gold = GoldParse(doc=current_doc, links=gold_entities)
data.append((sent, gold))
total_entities += 1
if len(data) % 500 == 0:
print(" -read", total_entities, "entities")

print("Read", total_entities, "entities in", len(data), "articles")
print(" -read", total_entities, "entities")
return data

@@ -9,7 +9,6 @@ from examples.pipeline.wiki_entity_linking import wikipedia_processor as wp, kb_
from examples.pipeline.wiki_entity_linking.kb_creator import DESC_WIDTH

import spacy
from spacy.vocab import Vocab
from spacy.kb import KnowledgeBase
import datetime
@@ -64,8 +63,8 @@ def run_pipeline():
to_test_pipeline = True

# write the NLP object, read back in and test again
to_write_nlp = True
to_read_nlp = True
to_write_nlp = False
to_read_nlp = False

# STEP 1 : create prior probabilities from WP
# run only once !
@@ -134,8 +133,8 @@ def run_pipeline():

if train_pipe:
print("STEP 6: training Entity Linking pipe", datetime.datetime.now())
train_limit = 5
dev_limit = 2
train_limit = 25000
dev_limit = 1000

train_data = training_set_creator.read_training(nlp=nlp_2,
training_dir=TRAINING_DIR,
@@ -345,7 +344,11 @@ def calculate_acc(correct_by_label, incorrect_by_label):
acc_by_label = dict()
total_correct = 0
total_incorrect = 0
for label, correct in correct_by_label.items():
all_keys = set()
all_keys.update(correct_by_label.keys())
all_keys.update(incorrect_by_label.keys())
for label in sorted(all_keys):
correct = correct_by_label.get(label, 0)
incorrect = incorrect_by_label.get(label, 0)
total_correct += correct
total_incorrect += incorrect

@@ -1079,36 +1079,39 @@ class EntityLinker(Pipe):

embed_width = cfg.get("embed_width", 300)
hidden_width = cfg.get("hidden_width", 32)
article_width = cfg.get("article_width", 128)
sent_width = cfg.get("sent_width", 64)
entity_width = cfg.get("entity_width") # no default because this needs to correspond with the KB
sent_width = entity_width

article_encoder = build_nel_encoder(in_width=embed_width, hidden_width=hidden_width, end_width=article_width, **cfg)
sent_encoder = build_nel_encoder(in_width=embed_width, hidden_width=hidden_width, end_width=sent_width, **cfg)
model = build_nel_encoder(in_width=embed_width, hidden_width=hidden_width, end_width=sent_width, **cfg)

# dimension of the mention encoder needs to match the dimension of the entity encoder
mention_width = article_width + sent_width
mention_encoder = Affine(entity_width, mention_width, drop_factor=0.0)
# article_width = cfg.get("article_width", 128)
# sent_width = cfg.get("sent_width", 64)
# article_encoder = build_nel_encoder(in_width=embed_width, hidden_width=hidden_width, end_width=article_width, **cfg)
# mention_width = article_width + sent_width
# mention_encoder = Affine(entity_width, mention_width, drop_factor=0.0)
# return article_encoder, sent_encoder, mention_encoder

return article_encoder, sent_encoder, mention_encoder
return model

def __init__(self, **cfg):
self.article_encoder = True
self.sent_encoder = True
self.mention_encoder = True
# self.article_encoder = True
# self.sent_encoder = True
# self.mention_encoder = True
self.model = True
self.kb = None
self.cfg = dict(cfg)
self.doc_cutoff = self.cfg.get("doc_cutoff", 5)
self.sgd_article = None
self.sgd_sent = None
self.sgd_mention = None
# self.sgd_article = None
# self.sgd_sent = None
# self.sgd_mention = None

def set_kb(self, kb):
self.kb = kb

def require_model(self):
# Raise an error if the component's model is not initialized.
if getattr(self, "mention_encoder", None) in (None, True, False):
if getattr(self, "model", None) in (None, True, False):
raise ValueError(Errors.E109.format(name=self.name))

def require_kb(self):
@@ -1121,12 +1124,19 @@ class EntityLinker(Pipe):
self.require_kb()
self.cfg["entity_width"] = self.kb.entity_vector_length

if self.mention_encoder is True:
self.article_encoder, self.sent_encoder, self.mention_encoder = self.Model(**self.cfg)
self.sgd_article = create_default_optimizer(self.article_encoder.ops)
self.sgd_sent = create_default_optimizer(self.sent_encoder.ops)
self.sgd_mention = create_default_optimizer(self.mention_encoder.ops)
return self.sgd_article
if self.model is True:
self.model = self.Model(**self.cfg)

if sgd is None:
sgd = self.create_optimizer()
return sgd

# if self.mention_encoder is True:
# self.article_encoder, self.sent_encoder, self.mention_encoder = self.Model(**self.cfg)
# self.sgd_article = create_default_optimizer(self.article_encoder.ops)
# self.sgd_sent = create_default_optimizer(self.sent_encoder.ops)
# self.sgd_mention = create_default_optimizer(self.mention_encoder.ops)
# return self.sgd_article

def update(self, docs, golds, state=None, drop=0.0, sgd=None, losses=None):
self.require_model()
@@ -1146,7 +1156,7 @@ class EntityLinker(Pipe):
docs = [docs]
golds = [golds]

article_docs = list()
# article_docs = list()
sentence_docs = list()
entity_encodings = list()

@@ -1173,34 +1183,32 @@ class EntityLinker(Pipe):
if kb_id == gold_kb:
prior_prob = c.prior_prob
entity_encoding = c.entity_vector

entity_encodings.append(entity_encoding)
article_docs.append(first_par)
# article_docs.append(first_par)
sentence_docs.append(sentence)

if len(entity_encodings) > 0:
doc_encodings, bp_doc = self.article_encoder.begin_update(article_docs, drop=drop)
sent_encodings, bp_sent = self.sent_encoder.begin_update(sentence_docs, drop=drop)
# doc_encodings, bp_doc = self.article_encoder.begin_update(article_docs, drop=drop)
# sent_encodings, bp_sent = self.sent_encoder.begin_update(sentence_docs, drop=drop)

concat_encodings = [list(doc_encodings[i]) + list(sent_encodings[i]) for i in
range(len(article_docs))]
mention_encodings, bp_mention = self.mention_encoder.begin_update(np.asarray(concat_encodings), drop=drop)
# concat_encodings = [list(doc_encodings[i]) + list(sent_encodings[i]) for i in range(len(article_docs))]
# mention_encodings, bp_mention = self.mention_encoder.begin_update(np.asarray(concat_encodings), drop=drop)

sent_encodings, bp_sent = self.model.begin_update(sentence_docs, drop=drop)
entity_encodings = np.asarray(entity_encodings, dtype=np.float32)

loss, d_scores = self.get_loss(scores=mention_encodings, golds=entity_encodings, docs=None)
mention_gradient = bp_mention(d_scores, sgd=self.sgd_mention)
loss, d_scores = self.get_loss(scores=sent_encodings, golds=entity_encodings, docs=None)
bp_sent(d_scores, sgd=sgd)

# gradient : concat (doc+sent) vs. desc
sent_start = self.article_encoder.nO
sent_gradients = list()
doc_gradients = list()
for x in mention_gradient:
doc_gradients.append(list(x[0:sent_start]))
sent_gradients.append(list(x[sent_start:]))

bp_doc(doc_gradients, sgd=self.sgd_article)
bp_sent(sent_gradients, sgd=self.sgd_sent)
# sent_start = self.article_encoder.nO
# sent_gradients = list()
# doc_gradients = list()
# for x in mention_gradient:
# doc_gradients.append(list(x[0:sent_start]))
# sent_gradients.append(list(x[sent_start:]))
# bp_doc(doc_gradients, sgd=self.sgd_article)
# bp_sent(sent_gradients, sgd=self.sgd_sent)

if losses is not None:
losses[self.name] += loss

@@ -1262,14 +1270,17 @@ class EntityLinker(Pipe):
first_par_end = sent.end
first_par = doc[0:first_par_end].as_doc()

doc_encoding = self.article_encoder([first_par])
# doc_encoding = self.article_encoder([first_par])
for ent in doc.ents:
sent_doc = ent.sent.as_doc()
if len(sent_doc) > 0:
sent_encoding = self.sent_encoder([sent_doc])
concat_encoding = [list(doc_encoding[0]) + list(sent_encoding[0])]
mention_encoding = self.mention_encoder(np.asarray([concat_encoding[0]]))
mention_enc_t = np.transpose(mention_encoding)
# sent_encoding = self.sent_encoder([sent_doc])
# concat_encoding = [list(doc_encoding[0]) + list(sent_encoding[0])]
# mention_encoding = self.mention_encoder(np.asarray([concat_encoding[0]]))
# mention_enc_t = np.transpose(mention_encoding)

sent_encoding = self.model([sent_doc])
sent_enc_t = np.transpose(sent_encoding)

candidates = self.kb.get_candidates(ent.text)
if candidates:
@@ -1278,7 +1289,7 @@ class EntityLinker(Pipe):
prior_prob = c.prior_prob * self.prior_weight
kb_id = c.entity_
entity_encoding = c.entity_vector
sim = float(cosine(np.asarray([entity_encoding]), mention_enc_t)) * self.context_weight
sim = float(cosine(np.asarray([entity_encoding]), sent_enc_t)) * self.context_weight
score = prior_prob + sim - (prior_prob*sim) # put weights on the different factors ?
scores.append(score)

@@ -1299,34 +1310,20 @@ class EntityLinker(Pipe):
serialize = OrderedDict()
serialize["cfg"] = lambda p: srsly.write_json(p, self.cfg)
serialize["kb"] = lambda p: self.kb.dump(p)
if self.mention_encoder not in (None, True, False):
serialize["article_encoder"] = lambda p: p.open("wb").write(self.article_encoder.to_bytes())
serialize["sent_encoder"] = lambda p: p.open("wb").write(self.sent_encoder.to_bytes())
serialize["mention_encoder"] = lambda p: p.open("wb").write(self.mention_encoder.to_bytes())
if self.model not in (None, True, False):
serialize["model"] = lambda p: p.open("wb").write(self.model.to_bytes())
exclude = util.get_serialization_exclude(serialize, exclude, kwargs)
util.to_disk(path, serialize, exclude)

def from_disk(self, path, exclude=tuple(), **kwargs):
def load_article_encoder(p):
if self.article_encoder is True:
self.article_encoder, _, _ = self.Model(**self.cfg)
self.article_encoder.from_bytes(p.open("rb").read())

def load_sent_encoder(p):
if self.sent_encoder is True:
_, self.sent_encoder, _ = self.Model(**self.cfg)
self.sent_encoder.from_bytes(p.open("rb").read())

def load_mention_encoder(p):
if self.mention_encoder is True:
_, _, self.mention_encoder = self.Model(**self.cfg)
self.mention_encoder.from_bytes(p.open("rb").read())
def load_model(p):
if self.model is True:
self.model = self.Model(**self.cfg)
self.model.from_bytes(p.open("rb").read())

deserialize = OrderedDict()
deserialize["cfg"] = lambda p: self.cfg.update(_load_cfg(p))
deserialize["article_encoder"] = load_article_encoder
deserialize["sent_encoder"] = load_sent_encoder
deserialize["mention_encoder"] = load_mention_encoder
deserialize["model"] = load_model
exclude = util.get_serialization_exclude(deserialize, exclude, kwargs)
util.from_disk(path, deserialize, exclude)
return self
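Note on the training-data format in the read_training hunks above: the reader now yields one (sentence, GoldParse) pair per gold entity, with character offsets made relative to the sentence rather than the whole article. A toy illustration of that offset conversion (values invented, not taken from the commit):

    # absolute offsets of a mention in the article, plus its Wikipedia title
    start, end, wp_title = 120, 133, "Douglas Adams"
    sent_start_char = 100                 # found_ent.sent.start_char in the diff
    gold_start = start - sent_start_char  # 20
    gold_end = end - sent_start_char      # 33
    gold_entities = [(gold_start, gold_end, wp_title)]
    # these tuples are passed to GoldParse via links=gold_entities,
    # and each (sentence, gold) pair is appended to the training data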